<h1> ECE4179 - Semi-Supervised Learning Project</h1>
<h2>FixMatch</h2>

In [157]:
import torch
from torchvision.datasets import STL10 as STL10
import torchvision.transforms as transforms
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.optim as optim
from randaugment import RandAugmentMC
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt

from IPython.display import clear_output
import os
import time
import copy

# For rotation
from PIL import Image
from typing import Any, Callable, Optional, Tuple

####### CHANGE TO APPROPRIATE DIRECTORY TO STORE DATASET
# Dataset root, resolved relative to the directory the notebook is started from
dataset_dir = os.getcwd()+"/CNN-VAE/data"
#For MonARCH
# dataset_dir = "/mnt/lustre/projects/ds19/SHARED"

#All images are 3x96x96
image_size = 96
#Labelled batch size (the unlabelled loader uses batch_size*mu)
batch_size = 24
# Number of classes for classification
out_classes = 10
# First and one-past-last value of the training loop counter; the loop below
# calls train() once per counter value
start_epoch = 0
n_epochs = 5000
# Learning rate handed to the Adam optimizer
lr = 2e-4
#No scheduler used

#temp = 0.5 Pseudo label temperature (not used by the code below)
mu = 10 #Ratio of unlabelled to labelled batch size
threshold = 0.95 #Minimum softmax confidence for a pseudo label to contribute to the loss
lambda_u = 2  #Weight of the unlabelled loss in the total loss

# Checkpoint behaviour: resume flag, output directory and file stem (saved as <save_dir>/<model_name>.pt)
start_from_checkpoint = False
save_dir = 'Models'
model_name = 'FixMatch'
    
# Hardware acceleration: use CUDA device GPU_indx when available, otherwise the CPU
GPU_indx = 0
device = torch.device(GPU_indx if torch.cuda.is_available() else 'cpu')

<h3>Create the appropriate transforms</h3>

In [158]:
#Channel-wise normalisation statistics shared by every transform pipeline
stl10_mean = (0.485, 0.456, 0.406)
stl10_std = (0.229, 0.224, 0.225)

#Labelled training images: reflect-padded random crop, mirroring and a small affine jitter
#(a RandAugmentMC(n=1, m=3) step is commented out of this pipeline)
transform_train = transforms.Compose([
    transforms.RandomCrop(image_size, padding=12, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(10, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(list(stl10_mean), list(stl10_std)),
])

#Evaluation images: no random operations, only a centre crop before normalisation
transform_test = transforms.Compose([
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(list(stl10_mean), list(stl10_std)),
])


class TransformFix(object):
    """Produce a weakly and a strongly augmented view of the same image.

    Both views share a random horizontal flip and a reflect-padded random crop;
    the strong view additionally applies ``RandAugmentMC(n=2, m=10)``. Each view
    is then converted to a tensor and normalised with ``mean`` / ``std``.

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel standard deviation passed to ``transforms.Normalize``.
        size: side length of the random crop (defaults to 96, the value that
            was previously hard-coded).
        padding_ratio: reflect padding applied before cropping, expressed as a
            fraction of ``size`` (truncated to an integer number of pixels).
    """

    def __init__(self, mean, std, size=96, padding_ratio=0.125):
        padding = int(size * padding_ratio)
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=size,
                                  padding=padding,
                                  padding_mode='reflect')])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=size,
                                  padding=padding,
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        """Return ``(weak_view, strong_view)`` for ``x``.

        The two views are drawn independently, so their flips and crops
        generally differ. ``x`` is presumably a PIL image, as produced by the
        torchvision dataset this transform is attached to -- confirm if reused
        elsewhere.
        """
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)

<h3>Create training and validation split</h3>

In [56]:
#Load the labelled split twice: once with the random training augmentations and once
#with the deterministic evaluation transform. random_split only selects indices, so a
#validation subset taken from trainval_set would still be randomly augmented.
trainval_set = STL10(dataset_dir, split='train', transform=transform_train, download=True)
trainval_eval_set = STL10(dataset_dir, split='train', transform=transform_test, download=True)

#Use 10% of data for training - simulating low data scenario
num_train = int(len(trainval_set)*0.1)

#Split data into train/val sets
torch.manual_seed(0) #Set torch's random seed so that random split of data is reproducible
train_set, val_split = random_split(trainval_set, [num_train, len(trainval_set)-num_train])

#Same validation indices as the split above, but read through the un-augmented copy
val_set = torch.utils.data.Subset(trainval_eval_set, val_split.indices)

#Load test set
test_set = STL10(dataset_dir, split='test', transform=transform_test, download=True)

Files already downloaded and verified
Files already downloaded and verified


<h3>Get the unlabelled data</h3>

In [57]:
#Each unlabelled sample is returned as a (weak view, strong view) pair by TransformFix
unlabelled_set = STL10(
    dataset_dir,
    split='unlabeled',
    transform=TransformFix(mean=stl10_mean, std=stl10_std),
    download=True,
)

Files already downloaded and verified


You may find later that you want to change how the unlabelled data is loaded. This might require you to subclass the STL10 class used above, or to create your own dataloader similar to the PyTorch one.
https://pytorch.org/docs/stable/_modules/torchvision/datasets/stl10.html#STL10

<h3>Create the four dataloaders</h3>

In [159]:
#Training loaders: shuffled, incomplete final batch dropped.
#The unlabelled loader yields mu times as many samples per batch as the labelled one.
train_loader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, drop_last=True)
unlabelled_loader = DataLoader(
    unlabelled_set, batch_size=batch_size*mu, shuffle=True, drop_last=True)

#Evaluation loaders: fixed order, incomplete final batch dropped
valid_loader = DataLoader(val_set, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_set, batch_size=batch_size, drop_last=True)

## Network

Let's use a ResNet18 architecture for our CNN...

In [160]:
#Untrained ResNet18 whose final fully connected layer is replaced by an
#out_classes-way classifier; the whole model is moved to the device in one step
net = torchvision.models.resnet18(pretrained=False)
net_fc_in = net.fc.in_features
net.fc = nn.Linear(net_fc_in, out_classes)
net = net.to(device)

#Adam over every parameter of the network, using the configured learning rate
optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))

In [161]:
#Module-level iterators over the labelled and unlabelled training loaders.
#NOTE(review): train() does not read these two globals (it manages its own iterators);
#in the cells shown here they are only referenced by the commented-out visualisation code.
labeled_iter = iter(train_loader)
unlabeled_iter = iter(unlabelled_loader)

In [162]:
#Human-readable class names, one per output class.
#NOTE(review): presumably ordered to match the integer targets returned by STL10 -- confirm.
class_labels = ["airplane", "bird", "car", "cat", "deer", "dog", "horse", "monkey", "ship", "truck"]

In [163]:
def normalize_img(img):
    """Rescale an image so that each slice along the last axis spans [0, 1].

    The minimum and maximum are taken over the first two axes (keeping them as
    size-1 dimensions), then ``img`` is shifted and divided by the resulting range.
    """
    lowest = img.min(axis=(0, 1), keepdims=True)
    highest = img.max(axis=(0, 1), keepdims=True)
    value_range = highest - lowest
    return (img - lowest) / value_range

#Visualise strongly augmented unlabeled images
# plt.figure(figsize = (15,10))
# (image_batch_w, image_batch_s),_ = unlabeled_iter.next()
# for tmpC1 in range(8):    
#     img = np.moveaxis(image_batch_s[tmpC1].numpy(),0,2)
#     plt.subplot(2,4,tmpC1+1)
#     plt.imshow(img)

In [164]:
def calculate_accuracy(fx, y):
    """Return the fraction of rows of ``fx`` whose arg-max along dim 1 equals ``y``.

    Args:
        fx: tensor of per-class scores, one row per sample.
        y: tensor of target class indices, reshaped to match the predictions.

    Returns:
        A 0-dim float tensor: number of correct predictions / number of rows.
    """
    _, predicted = fx.max(1, keepdim=True)
    hits = predicted.eq(y.view_as(predicted))
    n_correct = hits.sum()
    return n_correct.float() / predicted.shape[0]

In [165]:
def interleave(x, size):
    """Reorder the first axis of ``x`` by viewing it as (-1, size) groups and transposing.

    The leading axis is reshaped to ``[-1, size]``, the two new axes are swapped,
    and the result is flattened back to the original shape.
    """
    trailing = list(x.shape)[1:]
    grouped = x.reshape([-1, size] + trailing)
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1] + trailing)

def de_interleave(x, size):
    """Reorder the first axis of ``x`` by viewing it as (size, -1) groups and transposing.

    The leading axis is reshaped to ``[size, -1]``, the two new axes are swapped,
    and the result is flattened back to the original shape.
    """
    trailing = list(x.shape)[1:]
    grouped = x.reshape([size, -1] + trailing)
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1] + trailing)


In [166]:
#Create Save Path from save_dir and model_name, we will save and load our checkpoint here
Save_Path = os.path.join(save_dir, model_name + ".pt")

#Create the save directory if it does not exist
os.makedirs(save_dir, exist_ok=True)

#Best validation accuracy so far; defined up front so it exists when starting from
#scratch, and overwritten below when resuming from a checkpoint
best_valid_acc = 0

#Load Checkpoint
if start_from_checkpoint:
    #Refuse to continue if there is nothing to resume from
    if not os.path.isfile(Save_Path):
        raise ValueError("Checkpoint Does not exist")

    #map_location lets a checkpoint written on one device (e.g. a GPU) be restored
    #on whatever device this session is using
    check_point = torch.load(Save_Path, map_location=device)
    #The checkpoint is a python dictionary; unpack it to restore the previous training state
    net.load_state_dict(check_point['model_state_dict'])
    optimizer.load_state_dict(check_point['optimizer_state_dict'])
    start_epoch = check_point['epoch']
    best_valid_acc = check_point['valid_acc']
    print("Checkpoint loaded, starting from epoch:", start_epoch)
else:
    #Not resuming: only warn (do not raise) when a checkpoint file is already
    #present at Save_Path
    if os.path.isfile(Save_Path):
        print("Warning Checkpoint exists")
    else:
        print("Starting from scratch")



In [167]:
#Live iterators kept between train() calls, keyed by loader identity. The loader itself
#is stored next to its iterator so a recycled id() can never be mistaken for it.
_train_loader_iterators = {}


def _next_batch(loader):
    """Return the next batch from ``loader``, starting a new pass once it is exhausted."""
    key = id(loader)
    entry = _train_loader_iterators.get(key)
    if entry is None or entry[0] is not loader:
        entry = (loader, iter(loader))
        _train_loader_iterators[key] = entry
    try:
        return next(entry[1])
    except StopIteration:
        #End of this pass over the data: begin the next one
        entry = (loader, iter(loader))
        _train_loader_iterators[key] = entry
        return next(entry[1])


def train(net, device, labeled_trainloader, unlabeled_trainloader, test_loader, optimizer, loss_logger, acc_logger, mu, threshold, lambda_u):
    """Run a single optimisation step on one labelled and one unlabelled batch.

    The labelled batch contributes a cross-entropy loss ``Lx``. For the unlabelled
    batch, the arg-max of the softmax over the weak-view logits is used as a pseudo
    label for the strong view; samples whose maximum probability is below
    ``threshold`` are masked out, and the masked per-sample losses are averaged over
    the whole unlabelled batch to give ``Lu``. The step minimises
    ``Lx + lambda_u * Lu``.

    Args:
        net: model to train (switched to train mode).
        device: device the batch and targets are moved to.
        labeled_trainloader: loader yielding ``(inputs, targets)``.
        unlabeled_trainloader: loader yielding ``((weak, strong), _)``; its batch must
            hold ``mu`` times as many samples as the labelled batch for the
            ``2*mu+1`` interleave to line up.
        test_loader: unused; kept so existing callers keep working.
        optimizer: optimizer stepped once.
        loss_logger: list that receives the total loss of this step.
        acc_logger: list that receives the labelled-batch accuracy of this step.
        mu: ratio of unlabelled to labelled batch size.
        threshold: minimum pseudo-label confidence for a sample to contribute.
        lambda_u: weight of the unlabelled loss.

    Returns:
        ``(Lx, Lu, loss_logger, pseudo_label_cnt, acc_logger)`` where ``Lx`` and ``Lu``
        are floats and ``pseudo_label_cnt`` is a 0-dim tensor holding the number of
        unlabelled samples that passed the threshold.
    """
    #Train mode
    net.train()

    #Batches come from iterators that persist across calls, so successive steps walk
    #through each loader rather than restarting it every time. Only StopIteration
    #triggers a restart; any other data-loading error propagates.
    inputs_x, targets_x = _next_batch(labeled_trainloader)
    (inputs_u_w, inputs_u_s), _ = _next_batch(unlabeled_trainloader)

    batch_size = inputs_x.shape[0]
    #One forward pass over labelled + weak + strong samples; interleave/de_interleave
    #reorder the concatenated batch before and after the network
    inputs = interleave(torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*mu+1).to(device)
    targets_x = targets_x.to(device)
    logits = de_interleave(net(inputs), 2*mu+1)
    logits_x = logits[:batch_size]
    logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
    del logits

    #Supervised loss on the labelled batch
    Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

    #Pseudo labels from the weak view. detach() (not the in-place detach_()) is used
    #because logits_u_w is a view produced by chunk(), and no gradient should flow
    #through the pseudo labels.
    pseudo_label = torch.softmax(logits_u_w.detach(), dim=-1)
    max_probs, targets_u = torch.max(pseudo_label, dim=-1)
    mask = max_probs.ge(threshold).float()
    Lu = (F.cross_entropy(logits_u_s, targets_u, reduction='none') * mask).mean()

    #Compute total loss
    loss = Lx + lambda_u * Lu
    #Accuracy on the labelled batch
    acc = calculate_accuracy(logits_x, targets_x)

    #Zero gradients, backpropagate and take a single optimisation step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #Log the loss and acc for plotting
    loss_logger.append(loss.item())
    acc_logger.append(acc.item())
    pseudo_label_cnt = torch.sum(mask)
    clear_output(True)
    return Lx.item(), Lu.item(), loss_logger, pseudo_label_cnt, acc_logger

In [168]:
def evaluate(net, device, loader, Loss_fun, loss_logger = None, max_batches = 21):
    """Evaluate ``net`` on up to ``max_batches`` batches of ``loader``.

    Args:
        net: model to evaluate (switched to eval mode).
        device: device the inputs and targets are moved to.
        loader: iterable yielding ``(inputs, targets)`` batches.
        Loss_fun: callable ``Loss_fun(outputs, targets)`` returning a scalar tensor.
        loss_logger: optional list that receives every per-batch loss.
        max_batches: maximum number of batches to process; ``None`` processes the
            whole loader. The default of 21 matches the previously hard-coded limit.

    Returns:
        ``(mean_loss, mean_accuracy, loss_logger)``, where both means are averaged
        over the number of batches actually processed.

    Raises:
        ValueError: if ``max_batches`` is smaller than 1 or the loader yields no batch.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be at least 1 or None")

    epoch_loss = 0
    epoch_acc = 0
    n_batches = 0
    #Set network in evaluation mode
    net.eval()

    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            x = x.to(device)
            #Cast the targets to int64 on the target device (no detour through the CPU)
            y = y.to(device).long()
            #Forward pass
            fx = net(x)
            #Calculate loss
            loss = Loss_fun(fx, y)
            #Calculate the accuracy
            acc = calculate_accuracy(fx, y)
            #Accumulate the loss and acc and count the batch
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            n_batches += 1
            #Log the loss for plotting if we passed a logger to the function
            if loss_logger is not None:
                loss_logger.append(loss.item())
            print("EVALUATION: | Itteration [%d/%d] | Loss %.2f | Accuracy %.2f |" %(i+1 ,len(loader), loss.item(), 100*(epoch_acc/n_batches)))
            clear_output(True)
            #Stop once the batch budget is used up (checked after processing so that
            #no extra batch is fetched from the loader)
            if max_batches is not None and n_batches >= max_batches:
                break

    if n_batches == 0:
        raise ValueError("evaluate() received a loader that yielded no batches")

    #Average over the number of processed batches (not the last batch index)
    return epoch_loss/n_batches, epoch_acc/n_batches, loss_logger

In [169]:
#Bookkeeping for the run; the *_logger / *_log lists feed the plots below
cycle_loss = 0
cycle_loss_av = 0
pseudo_label_cnt = 0
pseudo_label_cnt_log = []
test_accs = []
training_loss_logger = []
Lx_logger = []
Lu_logger = []
training_acc_logger = []
train_loss_average = []
cycle_loss_log = []
validation_acc_logger = []
validation_loss_logger = []
loss_fn = nn.CrossEntropyLoss()
net.zero_grad()

#Validate and redraw the plots every eval_interval passes of the loop
eval_interval = 20
#No validation result yet; forces an evaluation on the first pass even when
#start_epoch is not a multiple of eval_interval (e.g. after resuming)
valid_acc = None

#NOTE: each pass of this loop is one train() call, i.e. one optimisation step, so the
#"Epoch" axis of the plots counts steps.
for epoch in range(start_epoch, n_epochs):
    Lx, Lu, training_loss_logger, pseudo_label_cnt, training_acc_logger = train(
        net=net, device=device, labeled_trainloader=train_loader,
        unlabeled_trainloader=unlabelled_loader, test_loader=test_loader,
        optimizer=optimizer, loss_logger=training_loss_logger,
        acc_logger=training_acc_logger, mu=mu, threshold=threshold, lambda_u=lambda_u)

    on_interval = epoch % eval_interval == 0
    if on_interval or valid_acc is None:
        valid_loss, valid_acc, validation_loss_logger = evaluate(net, device, valid_loader, loss_fn, validation_loss_logger)

    #The most recent validation accuracy is repeated until the next evaluation so
    #this series has one entry per pass, like the others
    validation_acc_logger.append(valid_acc)
    pseudo_label_cnt = pseudo_label_cnt.cpu().numpy()
    pseudo_label_cnt_log.append(pseudo_label_cnt)
    Lx_logger.append(Lx)
    #Log the unlabelled loss as weighted in the total loss (lambda_u, not a literal 2)
    Lu_logger.append(lambda_u * Lu)

    if on_interval:
        #(title, y-axis label, series) for each panel of the 2x3 grid, in order
        panels = [
            ('Train Loss Total', 'Loss', training_loss_logger),
            ('Train Accuracy', 'Accuracy (%)', training_acc_logger),
            ('Validation Accuracy', 'Accuracy (%)', validation_acc_logger),
            ('Contributing Pseudo Labels', 'Count per batch', pseudo_label_cnt_log),
            ('Loss: Labeled', 'Loss', Lx_logger),
            ('Loss: Unlabeled', 'Loss', Lu_logger),
        ]
        plt.figure(figsize = (15,10))
        for position, (title, y_label, series) in enumerate(panels, start=1):
            plt.subplot(2,3,position)
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.plot(series, c = "b")
        plt.show()

KeyboardInterrupt: 