In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from torch.utils.data import Dataset
import random
import torch.utils.data as data_utils
from PIL import Image
import wandb

PyTorch Version:  2.0.0+cu117
Torchvision Version:  0.15.1+cu117


In [2]:
# Run configuration
num_epochs = 10   # number of epochs to train for
batch_size = 32   # batch size for training (lower it if you run out of memory)
num_classes = 2   # number of classes in the dataset

### Helper functions

In [3]:
def train_model(model, dataloaders, progress, criterion, optimizer, num_epochs=25):
    '''
    Train a model and evaluate it on the validation set once per epoch.
    Every epoch runs a 'train' phase followed by a 'val' phase, prints the loss and
    accuracy of each phase and logs both to Weights & Biases.
    Parameters:
    model: network to optimise; moved to the GPU when one is available, otherwise the CPU
    dataloaders: dict with 'train' and 'val' loaders yielding (inputs, labels) batches
    progress: dict with 'train' and 'val' progress bars (reset once per phase, updated once per batch)
    criterion: loss called as criterion(outputs, labels)
    optimizer: optimiser stepping the model parameters
    num_epochs: number of passes over the data
    Returns:
    The trained model
    '''
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        metrics = {}  # phase -> (loss, accuracy) for this epoch

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            progress[phase].reset()
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                progress[phase].update()

                if is_train:
                    optimizer.zero_grad()

                # Gradients are only tracked while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                num_in_batch = inputs.size(0)
                # Assumes criterion returns the batch mean (the nn.CrossEntropyLoss default),
                # so scale by the batch size to get a per-sample sum - TODO confirm for other losses
                running_loss += loss.item() * num_in_batch
                running_corrects += (preds == labels).sum().item()
                samples_seen += num_in_batch

            # Average over the samples actually processed. A loader built with drop_last=True
            # skips the final partial batch, so len(dataset) would under-report both metrics.
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / denominator
            metrics[phase] = (epoch_loss, epoch_acc)

            print(f'Epoch {epoch + 1}/{num_epochs}, {phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        print()
        # Log the loss and accuracy values at the end of each epoch (plain Python floats)
        loss_train, acc_train = metrics['train']
        loss_valid, acc_valid = metrics['val']
        wandb.log({
            "Epoch": epoch,
            "Train Loss": loss_train,
            "Train Acc": acc_train,
            "Valid Loss": loss_valid,
            "Valid Acc": acc_valid})

    return model

In [None]:
# Create Inception module for model
class InceptionBlock(nn.Module):
    '''
    Inception module: four parallel branches whose outputs are concatenated along the channel axis.
    Every branch preserves the spatial size (stride 1 with matching padding), so only the
    channel count changes.
    Parameters:
    in_channels: number of channels of the input feature map
    output_1x1: output channels of the 1x1 branch
    output_5x5: pair (reduce, out) - channels of the 1x1 reduction and of the 5x5 convolution
    output_3x3: pair (reduce, out) - channels of the 1x1 reduction and of each of the two 3x3 convolutions
    output_pool: output channels of the 1x1 convolution after average pooling
    The output_3x3 and output_pool defaults are the values that used to be hard-coded (64/96 and 32).
    With all defaults the block outputs 256 channels.
    '''
    def __init__(self, in_channels, output_1x1=64, output_5x5=(48, 64), output_3x3=(64, 96), output_pool=32):
        super(InceptionBlock, self).__init__()

        reduce_5x5, out_5x5 = output_5x5
        reduce_3x3, out_3x3 = output_3x3

        # NOTE(review): no activation or normalisation sits between the stacked convolutions.

        self.branch1x1 = nn.Conv2d(in_channels, output_1x1, kernel_size=1)

        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, reduce_5x5, kernel_size=1),
            nn.Conv2d(reduce_5x5, out_5x5, kernel_size=5, padding=2)
        )

        self.branch3x3dbl = nn.Sequential(
            nn.Conv2d(in_channels, reduce_3x3, kernel_size=1),
            nn.Conv2d(reduce_3x3, out_3x3, kernel_size=3, padding=1),
            nn.Conv2d(out_3x3, out_3x3, kernel_size=3, padding=1)
        )

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, output_pool, kernel_size=1)
        )

        # Channel count of the concatenated output, handy when stacking blocks
        self.out_channels = output_1x1 + out_5x5 + out_3x3 + output_pool

    def forward(self, x):
        '''Run the four branches on x and concatenate their outputs on the channel axis (dim=1).'''
        branches = (
            self.branch1x1(x),
            self.branch5x5(x),
            self.branch3x3dbl(x),
            self.branch_pool(x),
        )
        return torch.cat(branches, dim=1)

In [None]:
# Define Inception model class

class InceptionV3(nn.Module):
    '''
    Work-in-progress Inception-v3-style classifier: convolutional stem, one Inception block,
    an auxiliary head and the main classification head.
    Both heads use adaptive average pooling, so they do not depend on the input resolution.
    The input must still be large enough for the 5x5 auxiliary pooling (75x75 works).
    Parameters:
    num_classes: number of output classes
    is_training: together with restore_logits, selects how `predictions` is computed (see forward)
    dropout_keep_prob: probability of KEEPING an activation in the main head
    restore_logits: when is_training is True, True selects softmax and False selects sigmoid
    forward returns (logits, predictions). While the module is in train() mode the auxiliary
    head output is stored in self.aux_output (None in eval() mode).
    '''
    def __init__(self, num_classes, is_training=True, dropout_keep_prob=0.8, restore_logits=True):
        super(InceptionV3, self).__init__()

        self.dropout_keep_prob = dropout_keep_prob
        self.is_training = is_training
        self.restore_logits = restore_logits
        self.aux_output = None

        # Stem
        self.conv0 = nn.Conv2d(3, 32, kernel_size=3, stride=2)
        self.conv1 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(64, 80, kernel_size=1)
        self.conv4 = nn.Conv2d(80, 192, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Inception blocks
        width_1x1, widths_5x5, widths_3x3, width_pool = 64, (48, 64), (64, 96), 32
        self.inception1 = InceptionBlock(192, width_1x1, widths_5x5, widths_3x3, width_pool)
        # Channels produced by the last Inception block: 64 + 64 + 96 + 32 = 256
        feature_channels = width_1x1 + widths_5x5[1] + widths_3x3[1] + width_pool
        # TODO: add more blocks (planned next: blocks taking 256 and 288 input channels)
        # and update feature_channels to the output width of the last one.

        # Auxiliary Head logits
        self.aux_logits = nn.Sequential(
            nn.AvgPool2d(kernel_size=5, stride=3),
            nn.Conv2d(feature_channels, 128, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),  # collapse whatever spatial extent is left to 1x1
            nn.Conv2d(128, 768, kernel_size=1),
            nn.Flatten(),
            nn.Linear(768, num_classes)
        )
        # Final pooling and prediction
        self.logits = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # global average pooling
            # nn.Dropout takes the probability of ZEROING a value, i.e. 1 - keep probability
            nn.Dropout(p=1.0 - dropout_keep_prob),
            nn.Flatten(),
            nn.Linear(feature_channels, num_classes)
        )

    def forward(self, inputs):
        '''
        Parameters:
        inputs: image batch with 3 channels
        Returns:
        (logits, predictions): predictions is sigmoid(logits) when is_training is True and
        restore_logits is False, otherwise softmax over the class dimension
        '''
        x = F.relu(self.conv0(inputs))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool2(x)

        # Apply Inception blocks
        x = self.inception1(x)

        # The auxiliary head is only evaluated in train() mode; it is not part of the return value
        self.aux_output = self.aux_logits(x) if self.training else None
        logits = self.logits(x)

        if self.is_training and not self.restore_logits:
            predictions = torch.sigmoid(logits)
        else:
            predictions = F.softmax(logits, dim=1)

        return logits, predictions

In [4]:
def initialise_model(num_classes, input_size=256):
    '''
    Build the simple CNN used in this notebook: three Conv-ReLU-MaxPool stages followed by
    a two-layer fully connected classifier.
    With the default input_size the layers match the model summary printed below
    (the first linear layer has 64 * 32 * 32 = 65536 input features).
    Parameters:
    num_classes: number of output classes (size of the last linear layer)
    input_size: height/width of the square input images; must be at least 8
    Returns:
    An nn.Sequential model
    '''
    # Define the model architecture
    channels = (3, 16, 32, 64)
    layers = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers += [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),  # keeps the spatial size
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves the spatial size
        ]

    # Three poolings shrink each side by a factor of 2**3 = 8
    feature_side = input_size // 8
    layers += [
        nn.Flatten(),
        nn.Linear(channels[-1] * feature_side * feature_side, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes),
    ]

    model = nn.Sequential(*layers)
    return model

In [5]:
# Split image folders into train, val, test
def split_data(patch_directory, split: list, seed):
    '''
    Function that takes in the split percentage for train/val/test sets, and randomly chooses which cases
    to allocate to which set (to ensure all patches from one case go into one set)
    Parameters:
    patch_directory: folder containing one sub-folder of patches per case (trailing slash optional)
    split: list of three percentages [train, val, test]; the test set receives every case
           left over after train and val have been filled
    seed: option to set the seed value for randomness
    Returns:
    3 lists for each of train/val/test, where each list contains the case names to be used in the set
    '''

    random.seed(seed)

    # Only sub-directories are cases; stray files (e.g. .DS_Store) are ignored.
    # NOTE(review): os.listdir order is filesystem dependent, so the same seed can give a
    # different split on another machine; sorting would fix that but change existing splits.
    case_names = [name for name in os.listdir(patch_directory)
                  if os.path.isdir(os.path.join(patch_directory, name))]

    # Count the patches of every case once
    patch_counts = {name: len(os.listdir(os.path.join(patch_directory, name))) for name in case_names}

    total_num_patches = sum(patch_counts.values())
    train_split, val_split, _ = split
    train_num_patches = int((train_split/100)*total_num_patches)
    val_num_patches = int((val_split/100)*total_num_patches)

    selected_cases = set() # cases already assigned to train or val

    def draw_cases(target_num_patches):
        # Draw unassigned cases at random until the patch target is reached
        # or no unassigned case is left (this prevents an endless loop).
        chosen = []
        num_selected = 0
        while num_selected < target_num_patches and len(selected_cases) < len(case_names):
            case = random.choice(case_names)
            if case not in selected_cases:
                selected_cases.add(case)
                chosen.append(case)
                num_selected += patch_counts[case]
        return chosen, num_selected

    # SELECT TRAINING CASES
    train_cases, num_selected_train = draw_cases(train_num_patches)

    # SELECT VAL CASES
    val_cases, num_selected_val = draw_cases(val_num_patches)

    # SELECT TEST CASES (everything not used so far)
    test_cases = [case for case in case_names if case not in selected_cases]
    num_selected_test = sum(patch_counts[case] for case in test_cases)

    print(f"Number of training patches: {num_selected_train} \nNumber of validation patches {num_selected_val} \nNumber of test patches {num_selected_test}")
    return train_cases, val_cases, test_cases

In [6]:
# Create a custom PyTorch dataset to read in your images and apply transforms

class CustomDataset(Dataset):
    '''
    Dataset of image patches with one label per patch.
    Parameters:
    img_folders: list of case folders containing the patch images
    label_files: list of .pt files (one per case, same order as img_folders) holding the labels
    transform: optional callable applied to each PIL image in __getitem__
    The label of a patch is looked up by the integer between the first and second underscore
    of its file name (after removing '.png') - presumably the patch index within the case;
    verify against the patch naming scheme.
    '''
    def __init__(self, img_folders, label_files, transform=None):
        self.img_folders = img_folders
        self.label_files = label_files
        self.transform = transform

        self.imgs = [] # Keeps image paths to load in the __getitem__ method
        self.labels = []

        # Load images and corresponding labels
        for img_folder, label_file in zip(img_folders, label_files):
            labels_pt = torch.load(label_file) # Load .pt file
            # Run through all patches from the case folder
            for img_name in os.listdir(img_folder):
                # Hidden files are never patches: '._*' AppleDouble files would duplicate the
                # real image (or point to a missing one) and '.DS_Store' cannot be parsed.
                if img_name.startswith('.'):
                    continue
                img_path = os.path.join(img_folder, img_name)
                if not os.path.isfile(img_path):
                    continue
                idx = int(img_name.replace('.png', '').split("_")[1])
                self.imgs.append(img_path)
                self.labels.append(labels_pt[idx].item()) # get label as a Python number

    def __len__(self):
        '''Number of patches found across all case folders.'''
        return len(self.imgs)

    def __getitem__(self, idx):
        '''Return (image, label) for the patch at position idx; the image is loaded as RGB.'''
        # Load image at given index
        image_path = self.imgs[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform is not None: # Apply transformations
            image = self.transform(image)

        label = self.labels[idx] # Load corresponding image label

        return image, label # Return transformed image and label

### Initialise simple CNN model

In [7]:
# Build the network for this run; num_classes sets the size of the output layer
CNN_model = initialise_model(num_classes)

# Print the layer-by-layer summary of the model we just instantiated
print(CNN_model)

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU()
  (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (9): Flatten(start_dim=1, end_dim=-1)
  (10): Linear(in_features=65536, out_features=512, bias=True)
  (11): ReLU()
  (12): Linear(in_features=512, out_features=2, bias=True)
)


In [8]:
# from torchsummary import summary
# summary(CNN_model)

### Load data

In [9]:
SEED = 42             # seed passed to split_data for the train/val/test split
PATCH_SIZE = 256      # patch side length (the loaded tensors are 3 x 256 x 256)
STRIDE = PATCH_SIZE   # stride equal to the patch size

In [10]:
# Initialise data transforms
# Every split currently uses the same pipeline: convert the PIL image to a tensor.
# Options deliberately left out for now:
#   transforms.Resize(INPUT_SIZE)                                        (all splits)
#   transforms.RandomHorizontalFlip()                                    (train only)
#   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])   (initially no colour normalisation)
data_transforms = {
    split_name: transforms.Compose([transforms.ToTensor()])
    for split_name in ('train', 'val', 'test')
}

In [11]:
# Build datasets and dataloaders from the full set of data

img_dir = '../data/patches/'
labels_dir = '../data/labels/'

split=[70, 15, 15] # train/val/test percentages

# Allocate whole cases to each set
train_cases, val_cases, test_cases = split_data(img_dir, split, SEED)

# Patch folder of every case, per set
train_img_folders, val_img_folders, test_img_folders = (
    [img_dir + case for case in cases]
    for cases in (train_cases, val_cases, test_cases)
)

# File path of the .pt label file of every case, per set
train_labels, val_labels, test_labels = (
    [labels_dir + case + '.pt' for case in cases]
    for cases in (train_cases, val_cases, test_cases)
)

image_datasets = {
    name: CustomDataset(folders, label_paths, transform=data_transforms[name])
    for name, folders, label_paths in (
        ('train', train_img_folders, train_labels),
        ('val', val_img_folders, val_labels),
        ('test', test_img_folders, test_labels),
    )
}

# Training, validation and test dataloaders; only the training loader drops its last partial batch.
# NOTE(review): 'val' and 'test' are shuffled as well.
dataloaders = {
    name: data_utils.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=True,
        drop_last=(name == 'train'),
    )
    for name, dataset in image_datasets.items()
}

# Use the first GPU if one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Number of training patches: 258108 
Number of validation patches 58768 
Number of test patches 47709


In [27]:
# Check format of data: size of the first training image after the transform
# (index [0] selects the first sample, the second [0] its image rather than its label)
print(image_datasets['train'][0][0].size())

torch.Size([3, 256, 256])


In [None]:
# Check device: shows whether the run will use the GPU ("cuda:0") or the CPU
print(device)

### Create optimiser

In [None]:
# Optimiser hyperparameters
learning_rate = 1e-3
# NOTE(review): despite its name, this value is passed to the optimiser as weight_decay
learning_rate_decay = 1e-7

In [None]:
# Send the model to the selected device (GPU if available, otherwise CPU)
CNN_model = CNN_model.to(device)

# Plain SGD over all model parameters; learning_rate_decay is used as the weight_decay argument
optimiser = optim.SGD(CNN_model.parameters(), lr=learning_rate, weight_decay=learning_rate_decay)

### Model

In [None]:
# W&B reads the notebook name from the WANDB_NOTEBOOK_NAME environment variable;
# a plain Python variable is invisible to it, so export the value as well.
WANDB_NOTEBOOK_NAME = 'simple_model'
os.environ["WANDB_NOTEBOOK_NAME"] = WANDB_NOTEBOOK_NAME

In [None]:
# Authenticate with Weights & Biases before starting a run
wandb.login()

In [None]:
# Initialize WandB: start the run that train_model logs its per-epoch metrics to
run = wandb.init(
    # Set the project where this run will be logged
    project="masters",
    # Free-text description stored with the run
    notes="Practice run, simple 6-layer CNN",
    # Track hyperparameters and run metadata
    config={
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        # learning_rate_decay is the value given to the optimiser as weight_decay
        "weight_decay": learning_rate_decay,
        "epochs": num_epochs,
    })

In [None]:
# One progress bar per phase, sized to the number of batches in that phase's dataloader
progress = {
    'train': tqdm(total=len(dataloaders['train']), desc="Training progress"),
    'val': tqdm(total=len(dataloaders['val']), desc="Validation progress"),
}

In [None]:
# Setup the loss fxn: cross-entropy between the model outputs and the integer class labels
criterion = nn.CrossEntropyLoss()

# Train and evaluate: runs the train and val phases for num_epochs epochs and
# rebinds CNN_model to the model returned by train_model
CNN_model = train_model(CNN_model, dataloaders, progress, criterion, optimiser, num_epochs=num_epochs)

In [None]:
# Save model!!

In [None]:
# Test model

In [None]:
# Make predictions on example image