<a href="https://colab.research.google.com/github/Renass/Yasuo_project/blob/transfer_res_net/classification/classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
from google.colab import files
import torchvision
import glob
from torchvision import transforms
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import matplotlib
from matplotlib import pyplot as plt
import zipfile

# Data preprocessing

In [2]:
from google.colab import drive

# Mount Google Drive so the zipped dataset stored there can be read.
drive.mount('/content/drive')


# Unpack the archive into the same root the dataset paths below are built
# from, instead of whatever the current working directory happens to be.
# The `with` block closes the archive handle even if extraction fails.
extract_root = '/content'
zip_path = '/content/drive/MyDrive/Colab_Notebooks/datasets/stm-data/stm-data.zip'
with zipfile.ZipFile(zip_path, 'r') as z:
    z.extractall(extract_root)
print(os.listdir(extract_root))


# Directory layout inside the archive: JointDataset/{Train,Test}/{bad,good}
path = extract_root + '/stm-data/JointDataset'
train_bad_dir = path + '/Train/bad'
train_good_dir = path + '/Train/good'
test_bad_dir = path + '/Test/bad'
test_good_dir = path + '/Test/good'

# Collect every .npy sample file of each split/class.
train_bad_names = glob.glob(train_bad_dir + '/*.npy')
train_good_names = glob.glob(train_good_dir + '/*.npy')
test_bad_names = glob.glob(test_bad_dir + '/*.npy')
test_good_names = glob.glob(test_good_dir + '/*.npy')
all_names = [train_bad_names, train_good_names, test_bad_names, test_good_names]

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
['.config', 'stm-data', 'drive', 'sample_data']


# Preparing batches

In [3]:
class MyDataset(torch.utils.data.Dataset):
    """Dataset of single-channel ``.npy`` images that all share one label.

    Each file is loaded with ``np.load``, replicated into three identical
    channels and returned as a ``float32`` tensor together with the label.

    Args:
        names: sequence of paths to ``.npy`` files.
        y: label returned unchanged with every sample of this dataset.
        transform: optional callable applied to the 3-channel tensor
            (still ``float64`` at that point) before the cast to ``float32``.
    """

    def __init__(self, names, y, transform = None):
        self.names = names
        self.y = y
        self.transform = transform
        # Kept so code that inspects this flag keeps working.
        self.should_transform = transform is not None

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image has shape ``(3, H, W)`` where ``(H, W)`` is the shape of the
        stored array (a leading singleton channel axis is accepted too).

        Raises:
            ValueError: if the stored array is not a 2-D image.
        """
        # Local variables only: nothing per-sample is stored on the instance,
        # so lookups do not leave stale arrays behind or interfere with each
        # other.
        image = np.asarray(np.load(self.names[idx]), dtype=np.float64)
        if image.ndim == 3 and image.shape[0] == 1:
            image = image[0]
        if image.ndim != 2:
            raise ValueError(
                "expected a 2-D image in {!r}, got shape {}".format(
                    self.names[idx], image.shape))

        # Replicate the single channel three times; np.repeat returns a fresh
        # contiguous array, so the tensor does not alias the loaded data.
        stacked = torch.from_numpy(np.repeat(image[np.newaxis, :, :], 3, axis=0))
        if self.transform is not None:
            stacked = self.transform(stacked)
        return stacked.float(), self.y

In [4]:
# Normalisation applied to every sample; a single mean/std pair is given for
# the three (identical) channels produced by MyDataset.
transform = torchvision.transforms.Compose([
    transforms.Normalize([0.45],[0.2])
    ])

# One place to tune the mini-batch size of both loaders.
BATCH_SIZE = 10

#bad: y=0
#good: y=1
train_bad_dataset = MyDataset(names=train_bad_names, y=float(0), transform=transform)
train_good_dataset = MyDataset(names=train_good_names, y=float(1), transform=transform)
train_dataset = train_bad_dataset + train_good_dataset

test_bad_dataset = MyDataset(names=test_bad_names, y=float(0), transform=transform)
test_good_dataset = MyDataset(names=test_good_names, y=float(1), transform=transform)
test_dataset = test_bad_dataset + test_good_dataset


# Training batches are reshuffled every epoch. Evaluation batches are not:
# the test loss is averaged per batch, so a fixed order keeps the reported
# value identical from run to run even when the last batch is smaller.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
#torch.Size([batch_size,channels,64,64])

# Train-Test procedures

In [5]:
def my_ln(x):
    """Natural logarithm of ``x`` shifted by a tiny positive constant.

    The constant ``10e-45`` (i.e. 1e-44) is added before ``torch.log`` is
    applied -- presumably so an exact zero is never passed to the logarithm.
    """
    shifted = x + 10e-45
    return torch.log(shifted)

def binary_cross_entropy_weighted(pred, target):
    """Mean binary cross-entropy with the ``target == 1`` term weighted by 2.

    Computes ``mean(-2 * target * ln(pred) - (1 - target) * ln(1 - pred))``,
    where ``ln`` is ``my_ln``.
    """
    positive_term = 2 * target * my_ln(pred)
    negative_term = (1 - target) * my_ln(1 - pred)
    return torch.mean(-positive_term - negative_term)





def train(model, train_loader, epochs, optimizer, criterion=None, device=None, plot=True):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Args:
        model: network whose forward pass returns per-class logits of shape
            ``(batch, num_classes)``.
        train_loader: iterable of ``(data, target)`` batches. Targets hold
            class indices; they may be stored as floats (e.g. ``0.``/``1.``)
            and are cast to ``long`` here.
        epochs: number of passes over ``train_loader``.
        optimizer: optimizer that updates the model parameters.
        criterion: callable ``criterion(logits, target) -> scalar loss``.
            Defaults to ``F.cross_entropy``; pass a weighted
            ``nn.CrossEntropyLoss`` to weight the classes differently.
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU for a parameter-free model).
        plot: when true, plot the per-epoch mean loss at the end.

    Returns:
        List with the mean batch loss of every epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if criterion is None:
        criterion = F.cross_entropy
    if device is None:
        # Follow the model instead of relying on a notebook-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    loss_epochs = []
    for idx in range(epochs):
        loss_samples = []
        for data, target in train_loader:
            data = data.to(device)
            # Class-index targets must be integer typed for cross-entropy.
            target = target.to(device).long()
            optimizer.zero_grad()   # zero the gradient buffers
            output = model(data)

            # The loss has to be computed for every batch; a module-level
            # loss object has no backward() to call.
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()    # Does the update
            loss_samples.append(loss.item())

        if not loss_samples:
            raise ValueError("train_loader yielded no batches; cannot compute a mean loss")
        loss_samples_mean = sum(loss_samples) / len(loss_samples)
        print(f"Epoch {idx: >8} Loss: {loss_samples_mean}")
        loss_epochs.append(loss_samples_mean)

    if plot:
        plt.plot(loss_epochs)
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.show()
    return loss_epochs

In [6]:
def test(model, device, test_loader):
    model.eval()
    loss=0
    accuracy = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:

            
            data, target = data.to(device), target.to(device)
            output = model.forward(data)
            output = output.double()
            target = target.unsqueeze(1)

            loss+= F.cross_entropy(output,target)
            #loss+= binary_cross_entropy_weighted(output, target)
            pred = target.clone()
            for i,x in enumerate(output):
                ans = 0.
                if x > 0.5:
                    ans = 1.
                pred[i] = ans 


            for i,single_pred in enumerate(pred):
                if single_pred == target[i]:
                    correct+= 1
            total += len(pred)
    loss = loss/len(test_loader)
    accuracy = correct / total
    
    print('Loss: ',loss.data.item())
    print('\n Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, total, 100. * accuracy))

# Main loop

In [7]:
# Pick the GPU when one is available, unless it is switched off by hand.
no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

# Pretrained ResNet-18 backbone with all of its weights frozen.
model = torchvision.models.resnet18(pretrained = True)
model = model.to(device)
model.requires_grad_(False)

# New 2-output classification head. It is created after the freeze, so its
# parameters are the only trainable ones.
head_inputs = model.fc.in_features
model.fc = nn.Linear(head_inputs, 2).to(device)

# Loss, optimizer and a step-wise learning-rate schedule.
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

cpu


In [8]:
#Main loop train epochs_train times and test once and
#repeat it train_cycles times
train_cycles=3
epochs_train=10
lr=10e-4  # note: 10e-4 == 1e-3

# Build the optimizer once for the whole run. Creating a fresh Adam inside
# the loop would throw away its running moment estimates at the start of
# every cycle.
cycle_optimizer = optim.Adam(model.parameters(), lr=lr)

#Main loop
for i in range(train_cycles):
    train(model=model, train_loader=train_loader, epochs=epochs_train, optimizer=cycle_optimizer)
    test(model=model,device=device,test_loader=test_loader)

KeyboardInterrupt: ignored