In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from tensorboardX import SummaryWriter
from MLDataTools.image_normalization import RandomDihedral
from skimage.external import tifffile as tiff
import numpy as np
from matplotlib import pyplot

%matplotlib inline

# torch setup
torch.set_default_tensor_type('torch.DoubleTensor')
# torch.set_default_tensor_type(torch.DoubleTensor) # so it doesnt throw a incompatible type exception
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # important for cloud compatability

In [2]:
# tensorboard setup
from datetime import datetime
date = datetime.now().strftime("%Y-%m-%d.%H:%M:%S")

writer = SummaryWriter('tensorboardx/ResNet50_'+date)

## Model Definition

In [3]:
net = models.resnet18()

In [4]:
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, 4) # 4 = num classes

## Data Loading and Pre-processing

In [5]:
class ToImage:
    def __call__(self, sample):
        sample = torch.from_numpy(sample)
        zeros = torch.zeros(1,200,200)
        return torch.cat((sample, zeros))
            


In [6]:
ds_transforms = transforms.Compose([
    ToImage(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    RandomDihedral(),
])

In [7]:
# lets us control the datatype when the image is being read
def tiff_read(path:str):
    image = tiff.imread(path).astype(np.float)
    return image

In [8]:
BATCH_SIZE = 40
DATA_ROOT = '/home/user/HDev Dropbox/Projects/YNet_ready_data/yeast_v4'

trainset = torchvision.datasets.ImageFolder(DATA_ROOT+'/train', transform=ds_transforms, loader=tiff_read)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

testset = torchvision.datasets.ImageFolder(DATA_ROOT+'/test', transform=ds_transforms, loader=tiff_read)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# invert class_to_id
idx_to_class = {v:k for k,v in trainset.class_to_idx.items()}

## Training and Testing procedures


In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

In [10]:
def train(epoch):
    net.train() # affects only modules like Dropout
    trainiter = iter(trainloader)
    for batch_idx, (data, targets) in enumerate(trainiter, 0):
        # get the inputs

        data, targets = Variable(data), Variable(targets)
        l = data.size(0)

        # backprop
        optimizer.zero_grad() # dont forget to do that
        output = net(data)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        scheduler.step(loss)
        
        # tensorboard
        global global_step
        global_step += 1
        writer.add_scalar('Train_Loss', loss.data.item(), global_step)
        if batch_idx % 2 == 0: # every 2nd batch add to our embedding writer
            targets = targets.type(torch.DoubleTensor)
            writer.add_embedding(output, metadata=targets.data, label_img=data.data, global_step=global_step)
        if batch_idx % 5 == 0:
            samples_done = batch_idx * BATCH_SIZE
            percent = 100. * samples_donne / len(trainloader.dataset)
            print(f"Train Epoch: {epoch} [{samples_done}/{len(trainloader.dataset)} "
                  f"({percent:8.5}%)]\tLoss: {loss.item():10.5}")
        

In [11]:
def test(epoch):
    with torch.no_grad():
        net.eval()
        test_loss = 0
        wcases = [] # list of worst cases
        classes_correct = list(0 for i in range(NUM_CLASSES))
        classes_total = list(0 for i in range(NUM_CLASSES))
        
        for data, targets in iter(testloader):
            # NOTE: UPDATE THIS LINE WHEN TORCH VERSION == 0.4
            data, targets = Variable(data), Variable(targets)
            
            output = net(data)
            # errors is a top 2 List[error:float, index of sample:int]
            errors = worst_cases(output, targets)  
            
            for error, idx in errors:
                target = targets[idx].data.item()
                label = f"{idx_to_class[target]}_{error:10.5}"
                tensor = data[idx].numpy().copy()
                wcases.append((tensor,label))
            
            # sum up batch loss
            test_loss += criterion(output, targets).data.item()
            
            
            # get the index of the max log-probability
            _, pred = output.max(1) # returns a tuple the last element is the index Tensor
            c = (pred == targets).squeeze()
            l = c.size(0) # to account for different batch sizes
            # this helps to identify which classes the network is struggling with
            for i in range(l):
                label = targets.data[i].item()
                classes_correct[label] += c.data[i].item()
                classes_total[label] += 1
                        
        writer.add_scalar('Test_Loss', test_loss, epoch)
        test_loss /= len(testloader.dataset)
        accuracy = 100. * (sum(classes_correct) / sum(classes_total))
        print(f"\nTest set: Average loss: {test_loss:6.5}, Accuracy: {accuracy:10.5}\n")
        
        for i, total, correct in zip(range(NUM_CLASSES), classes_total, classes_correct):
            cl = idx_to_class[i]
            cl_accuracy = 100. * (classes_correct[i] / classes_total[i])
            print(f"class [{cl}]: accuracy {cl_accuracy:10.4}%")
        print()# prints a newline
        
#         for image, label in wcases[:5]: # List[image:np.array, class:str]
#             # should be of dimensions (2, 200, 200)
#             tiff.imshow(image, title=label)

In [12]:
# helps to identify which cases the network was really wrong about
def worst_cases(output: torch.Tensor, targets: torch.Tensor, top=2):
    assert output.size(0) == targets.size(0)
    length = output.size(0)
    errors = []
    for i in range(length):
        z = torch.zeros(NUM_CLASSES)
        label = targets[i].item()
        z[label] = 1
        diff = (output[i] - z).numpy().copy()
        diff = np.sum(np.abs(diff))
        errors.append((diff, i))
        
    errors.sort(key=lambda x: x[0], reverse=True)
        
    return errors[:top]

In [13]:
EPOCHS = 10
global_step = 0
NUM_CLASSES = 4

for i in range(EPOCHS):
    train(i)
    test(i)
    
print("\n Finished training.")

NameError: name 'samples_donne' is not defined