In [1]:
import torch
import torchvision
import os
import numpy as np
import matplotlib.pyplot as plt

In [2]:
torch.cuda.is_available()

True

In [3]:
USE_GPU = True
num_class = 100
dtype = torch.float32 # we will be using float throughout this tutorial
device = torch.device("cpu")
try:
    if USE_GPU and torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device('cpu')
except:
    if USE_GPU and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100
print('using device:', device)

using device: cuda


In [4]:
def plotImage(image):
    return plt.imshow(image.T)

In [5]:
BASE = "lung_colon_image_set/"

In [6]:
class LungColonDS(torchvision.datasets.DatasetFolder):
    def find_classes(self, directory: str):
        classes = ["colon_image_sets/colon_aca","colon_image_sets/colon_n","lung_image_sets/lung_aca","lung_image_sets/lung_n","lung_image_sets/lung_scc"]
        return classes, {val:ind for ind, val in enumerate(classes)}

In [7]:
lung_colon_dataset = LungColonDS(BASE,torchvision.io.read_image, extensions = tuple([".jpeg"]))

# Making Dataloader
Dataloaders is how we will present the data to the trainer, so we need to create this object. In the future, these can be adjusted to handle any kind of new distribution we want to make

In [8]:
from torch.utils.data import DataLoader

train_count = int(0.7*len(lung_colon_dataset))
valid_count = len(lung_colon_dataset) - train_count
train_dataset, valid_dataset = torch.utils.data.random_split(lung_colon_dataset, [train_count, valid_count])

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

In [9]:
imgs, labels = next(iter(train_dataloader))

In [10]:
labels.shape

torch.Size([64])

In [11]:
imgs[0].shape

torch.Size([3, 768, 768])

# Training Resnet With Rotated Images

In [12]:
import torch
import torchvision
import os
import numpy as np
import matplotlib.pyplot as plt

In [13]:
BASE = "lung_colon_image_set/"

In [14]:
class LungColonDS(torchvision.datasets.DatasetFolder):
    def find_classes(self, directory: str):
        classes = ["colon_image_sets/colon_aca","colon_image_sets/colon_n","lung_image_sets/lung_aca","lung_image_sets/lung_n","lung_image_sets/lung_scc"]
        return classes, {val:ind for ind, val in enumerate(classes)}

In [15]:
lung_colon_dataset = LungColonDS(BASE,torchvision.io.read_image, extensions = (".jpeg",))
import RotatedDataset
lung_colon_rotated_dataset = RotatedDataset.RotatedDataset(lung_colon_dataset, use_both_labels = True)

In [16]:
lung_colon_dataset.classes

['colon_image_sets/colon_aca',
 'colon_image_sets/colon_n',
 'lung_image_sets/lung_aca',
 'lung_image_sets/lung_n',
 'lung_image_sets/lung_scc']

In [17]:
train_count = int(0.7*len(lung_colon_rotated_dataset))
valid_count = len(lung_colon_rotated_dataset) - train_count
train_dataset, valid_dataset = torch.utils.data.random_split(lung_colon_rotated_dataset, [train_count, valid_count])

In [18]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

In [19]:
import resnet
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [20]:
learning_rate = 2.5e-4
model = resnet.ResNet.TTT_Implementation(5,4,True)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [27]:
def check_accuracy_TTT(loader, model):
    #Only check accuracy of main task
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y[0]
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)[0]
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples if float(num_correct) != 0 else 0
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

def trainTTT(model, optimizer, epochs=1):
    """
    Train a TTT model.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: The accuracy of the model
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(train_dataloader):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y0,y1 = y
            y0 = y0.to(device=device, dtype=torch.long)
            y1 = y1.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores[0], y0) + F.cross_entropy(scores[1], y1)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if (t + 1) % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (e, t + 1, loss.item()))
                check_accuracy_TTT(test_dataloader, model)
                print()
            if t == 0:
                print(f"Passed first iteration! Loss: {loss.item()}")
    return check_accuracy_TTT(test_dataloader, model)

In [28]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f06503d7a90>

In [29]:
print_every = 100
trainTTT(model, optimizer, epochs=10)
print_every = 100

Passed first iteration! Loss: 1.7062921524047852
Epoch 0, Iteration 100, loss = 1.7001
Got 5554 / 7500 correct (74.05)

Epoch 0, Iteration 200, loss = 1.6517
Got 6693 / 7500 correct (89.24)

Passed first iteration! Loss: 1.7489622831344604
Epoch 1, Iteration 100, loss = 1.5911
Got 5718 / 7500 correct (76.24)

Epoch 1, Iteration 200, loss = 1.5455
Got 5684 / 7500 correct (75.79)

Passed first iteration! Loss: 1.529234766960144
Epoch 2, Iteration 100, loss = 1.6387
Got 6243 / 7500 correct (83.24)

Epoch 2, Iteration 200, loss = 1.5964
Got 5688 / 7500 correct (75.84)

Passed first iteration! Loss: 1.4618902206420898
Epoch 3, Iteration 100, loss = 1.5363
Got 6177 / 7500 correct (82.36)

Epoch 3, Iteration 200, loss = 1.6065
Got 6456 / 7500 correct (86.08)

Passed first iteration! Loss: 1.4556032419204712
Epoch 4, Iteration 100, loss = 1.5090
Got 5731 / 7500 correct (76.41)

Epoch 4, Iteration 200, loss = 1.5874
Got 3114 / 7500 correct (41.52)

Passed first iteration! Loss: 1.35712170600891