In [1]:
!pip install -U PyYAML

Collecting PyYAML
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 27.5 MB/s 
[?25hInstalling collected packages: PyYAML
  Attempting uninstall: PyYAML
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed PyYAML-6.0


In [2]:
#@title Mount your Google Drive
# If you run this notebook locally or on a cluster (i.e. not on Google Colab)
# you can delete this cell which is specific to Google Colab. You may also
# change the paths for data/logs in Arguments below.
%matplotlib inline
%load_ext autoreload
%autoreload 2
''
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
#@title Link your assignment folder & install requirements
#@markdown Enter the path to the assignment folder in your Google Drive
# If you run this notebook locally or on a cluster (i.e. not on Google Colab)
# you can delete this cell which is specific to Google Colab. You may also
# change the paths for data/logs in Arguments below.
import sys
import os
import shutil
import warnings

folder = "/content/gdrive/MyDrive/MILA/IFT6759/Project/src" #@param {type:"string"}
!ln -Ts "$folder" /content/src 2> /dev/null

# Add the assignment folder to Python path
if '/content/src' not in sys.path:
  sys.path.insert(0, '/content/src')

# Install requirements
#!pip install -qr /content/assignment/requirements.txt

# Check if CUDA is available
import torch
if not torch.cuda.is_available():
  warnings.warn('CUDA is not available.')

In [4]:
import argparse
import importlib
import logging
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import yaml


In [5]:
# Name of the experiment configuration. Later cells reuse it as the stem of the
# log-file and pickle-file names.
config_name = "Sup_Config1"

# Parse ./src/Config/<config_name>.yaml into `config`.
config_path = f"./src/Config/{config_name}.yaml"
with open(config_path) as f:
    config = yaml.load(stream=f, Loader=yaml.FullLoader)


In [6]:
# Run on the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first `.to(device)` call (an earlier cell only warns when CUDA
# is unavailable).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment settings, read once from the parsed YAML configuration.
task = config["task"]                      # compared against "super" / "semi" by the training cells
data_file = config["data"]                 # module name under the Data package
model_file = config["model"]               # module name under the Model package
augment_file = config["augment"]           # list of Augmentation module names, or None
augment_strength = config["aug_strength"]  # one value per entry of augment_file
eval_file = config["eval"]                 # module name under the Evaluation package
batch_size = config["batch_size"]
learn_rate = config["learning_rate"]
epoch = config["epoch"]                    # number of training epochs
optimizer = config["optimizer"]
#momentum = config["momentum"]
weight_decay = config["weight_decay"]
seed = config["seed"]

In [7]:
# Resolve the dataset module Data.<data_file> selected in the configuration and
# expose its Data_Load entry point.
logging.info(f"==========Dataset: {data_file}==========")
data_file_path = f"Data.{data_file}"
# importlib.import_module is the documented way to import a module by name and
# returns the submodule itself, which is what __import__(..., fromlist=...) did here.
Data_Load = importlib.import_module(data_file_path).Data_Load

In [8]:
# Collect the augmentation callables named in the configuration, in order.
# Aug is always defined (empty when no augmentation is configured) so later
# cells can reference it unconditionally.
Aug = []
if augment_file is None:
    print("No augmentation method selected")
else:
    for i, aug_name in enumerate(augment_file):
        logging.info(f"==========Augmentation Methods: {aug_name}, with a strength value of {augment_strength[i]}==========")
        augment_file_path = f"Augmentation.{aug_name}"
        # Each Augmentation.<name> module is expected to expose an `Aug` attribute.
        Aug.append(importlib.import_module(augment_file_path).Aug)

In [9]:
# Importing the model class: resolve Model.<model_file> from the configuration
# and expose its ModelClass attribute.
logging.info(f"==========Model Selected: {model_file}==========")
model_file_path = f"Model.{model_file}"
# importlib.import_module is the documented way to import a module by name.
ModelClass = importlib.import_module(model_file_path).ModelClass

In [10]:
# Importing the evaluation methods: resolve Evaluation.<eval_file> from the
# configuration and expose its Eval attribute.
logging.info(f"==========Evaluation Method: {eval_file}==========")
eval_file_path = f"Evaluation.{eval_file}"
# importlib.import_module is the documented way to import a module by name.
Eval = importlib.import_module(eval_file_path).Eval

In [11]:
# Build the labelled, unlabelled, validation and test loaders for the
# configured task, batch size and seed.
labelledloader, unlabelledloader, validloader, testloader = Data_Load(
    task=task,
    batch_size=batch_size,
    seed=seed,
)
logging.info("Dataloader ready")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./Data/Cifar10_Data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./Data/Cifar10_Data/cifar-10-python.tar.gz to ./Data/Cifar10_Data
Files already downloaded and verified


In [None]:
# Supervised training: fit a fresh model on the labelled loader and validate it
# after every epoch. The per-epoch mean accuracy and loss are appended to the
# *_tot_* lists, which the plotting and logging cells read afterwards.
if task == "super":
    torch.manual_seed(seed)

    train_tot_accs, valid_tot_accs = [], []
    train_tot_losses, valid_tot_losses = [], []

    Model = ModelClass(optimizer=optimizer, lr=learn_rate, weight_decay=weight_decay)
    Model = Model.to(device=device)

    for ep in range(epoch):

        print(f"==========Supervised Learning Epoch Number: {ep+1}/{epoch}==========")
        train_accs, valid_accs = [], []
        train_losses, valid_losses = [], []

        for data, target in labelledloader:
            data = data.to(device=device)
            labels = F.one_hot(target, num_classes=10).float().to(device=device)

            if augment_file is not None:
                # Split the batch into consecutive slices, one per augmentation,
                # each covering augment_strength[i] * batch_len samples. Slice
                # bounds are the rounded running total, kept as plain Python
                # ints: the previous torch.cat / torch.cumsum calls on a list of
                # numbers raised, and float bounds are not valid slice indices.
                # If Aug and augment_strength differ in length, zip stops at the
                # shorter one.
                batch_len = data.shape[0]
                aug_parts, label_parts = [], []
                running_total = 0
                start = 0
                for aug_fn, strength in zip(Aug, augment_strength):
                    running_total += strength * batch_len
                    stop = int(round(running_total))
                    part, part_labels = aug_fn(data[start:stop], labels[start:stop], torch.rand(1))
                    aug_parts.append(part)
                    label_parts.append(part_labels)
                    start = stop
                Aug_data = torch.cat(aug_parts, 0)
                Aug_labels = torch.cat(label_parts, 0)
            else:
                # No augmentation configured: repeat the batch four times.
                Aug_data = torch.cat((data, data, data, data), 0)
                Aug_labels = torch.cat((labels, labels, labels, labels), 0)

            acc, loss = Model.train_sup_up(Aug_data, Aug_labels)
            train_accs.append(acc)
            train_losses.append(loss)

        train_tot_accs.append(sum(train_accs) / len(train_accs))
        train_tot_losses.append(sum(train_losses) / len(train_losses))

        print(f"==========Training Accuracy: {train_tot_accs[-1]:.3f} , Training Loss: {train_tot_losses[-1]:.3f}==========")

        # NOTE(review): whether Model.evaluation disables gradients / switches to
        # eval mode internally is not visible from this notebook — confirm.
        for data, target in validloader:
            data = data.to(device=device)
            labels = F.one_hot(target, num_classes=10).float().to(device=device)
            acc, loss = Model.evaluation(data, labels)
            valid_accs.append(acc)
            valid_losses.append(loss)

        valid_tot_accs.append(sum(valid_accs) / len(valid_accs))
        valid_tot_losses.append(sum(valid_losses) / len(valid_losses))

        print(f"==========Validation Accuracy: {valid_tot_accs[-1]:.3f} , Validation Loss: {valid_tot_losses[-1]:.3f}==========")



In [None]:
def _augment_batch(data, labels):
    """Return the (data, labels) pair that is fed to ``train_sup_up`` for one batch.

    Reads the notebook-level ``augment_file``, ``augment_strength`` and ``Aug``.
    Without augmentation the batch is repeated four times. Otherwise the batch
    is cut into consecutive slices, one per augmentation, each covering
    ``augment_strength[i] * batch_len`` samples, and the augmented slices are
    concatenated in order.
    """
    if augment_file is None:
        return (torch.cat((data, data, data, data), 0),
                torch.cat((labels, labels, labels, labels), 0))

    # Slice bounds are the rounded running total, kept as plain Python ints:
    # torch.cat / torch.cumsum on a list of numbers raises, and float bounds are
    # not valid slice indices. If Aug and augment_strength differ in length,
    # zip stops at the shorter one.
    batch_len = data.shape[0]
    aug_parts, label_parts = [], []
    running_total = 0
    start = 0
    for aug_fn, strength in zip(Aug, augment_strength):
        running_total += strength * batch_len
        stop = int(round(running_total))
        part, part_labels = aug_fn(data[start:stop], labels[start:stop], torch.rand(1))
        aug_parts.append(part)
        label_parts.append(part_labels)
        start = stop
    return torch.cat(aug_parts, 0), torch.cat(label_parts, 0)


def _train_one_pass(model, loader, labeller=None):
    """Train ``model`` on every batch of ``loader``; return per-batch (accs, losses).

    With ``labeller=None`` the loader's targets are one-hot encoded (10 classes).
    Otherwise the targets are ignored and ``labeller.forward(data)`` supplies
    the labels.
    """
    accs, losses = [], []
    for data, target in loader:
        data = data.to(device=device)
        if labeller is None:
            labels = F.one_hot(target, num_classes=10).float().to(device=device)
        else:
            # Pseudo-labels are fixed targets: no gradient should flow back
            # into the labeller while the final model is trained.
            # NOTE(review): the labeller's train/eval mode is not switched here
            # because ModelClass is not visible from this notebook — confirm.
            with torch.no_grad():
                labels = labeller.forward(data)
        aug_data, aug_labels = _augment_batch(data, labels)
        acc, loss = model.train_sup_up(aug_data, aug_labels)
        accs.append(acc)
        losses.append(loss)
    return accs, losses


def _collect_metrics(model, loader):
    """Run ``model.evaluation`` on every batch of ``loader``; return per-batch (accs, losses)."""
    accs, losses = [], []
    for data, target in loader:
        data = data.to(device=device)
        labels = F.one_hot(target, num_classes=10).float().to(device=device)
        acc, loss = model.evaluation(data, labels)
        accs.append(acc)
        losses.append(loss)
    return accs, losses


def _mean(values):
    """Arithmetic mean of a non-empty sequence (raises ZeroDivisionError when empty)."""
    return sum(values) / len(values)


# Semi-supervised training in three stages: (1) train a labeller on the labelled
# data, (2) measure the labeller on the unlabelled loader, (3) train the final
# model on the labelled data plus the unlabelled data with labeller predictions
# as targets.
if task == "semi":

    torch.manual_seed(seed)

    # ---- Stage 1: training the labeller ----
    lab_train_tot_accs, lab_valid_tot_accs = [], []
    lab_train_tot_losses, lab_valid_tot_losses = [], []

    Labeller = ModelClass(optimizer=optimizer, lr=learn_rate, weight_decay=weight_decay)
    # Move the labeller itself to the device (this used to rebind Labeller to
    # `Model`, which is undefined or stale at this point).
    Labeller = Labeller.to(device=device)

    for ep in range(epoch):

        print(f"==========Semi-supervised Learning Labeller Epoch Number: {ep+1}/{epoch}==========")

        train_accs, train_losses = _train_one_pass(Labeller, labelledloader)
        lab_train_tot_accs.append(_mean(train_accs))
        lab_train_tot_losses.append(_mean(train_losses))

        print(f"==========Training Accuracy: {lab_train_tot_accs[-1]:.3f} , Training Loss: {lab_train_tot_losses[-1]:.3f}==========")

        valid_accs, valid_losses = _collect_metrics(Labeller, validloader)
        lab_valid_tot_accs.append(_mean(valid_accs))
        lab_valid_tot_losses.append(_mean(valid_losses))

        print(f"==========Validation Accuracy: {lab_valid_tot_accs[-1]:.3f} , Validation Loss: {lab_valid_tot_losses[-1]:.3f}==========")

    # ---- Stage 2: accuracy and loss when predicting the unlabelled data ----
    lab_accs, lab_losses = _collect_metrics(Labeller, unlabelledloader)
    lab_tot_accs = _mean(lab_accs)
    lab_tot_losses = _mean(lab_losses)

    # Printed (like every other progress message in the training cells) so it is
    # visible without configuring the logging module.
    print(f"==========Labelled Accuracy: {lab_tot_accs:.3f} , Labelled Loss: {lab_tot_losses:.3f}==========")

    # ---- Stage 3: train the final model with labelled data and unlabelled data
    # where the target is predicted by the labeller ----
    train_tot_accs, valid_tot_accs = [], []
    train_tot_losses, valid_tot_losses = [], []

    Model = ModelClass(optimizer=optimizer, lr=learn_rate, weight_decay=weight_decay)
    Model = Model.to(device=device)

    for ep in range(epoch):

        print(f"==========Semi-supervised Learning Model Epoch Number: {ep+1}/{epoch}==========")

        # NOTE(review): the labelled pass previously called Model.Train(data, target)
        # with CPU tensors and integer targets; it now goes through the same
        # device / one-hot / augmentation / train_sup_up path as every other
        # training loop in this notebook — confirm this matches ModelClass.
        labelled_accs, labelled_losses = _train_one_pass(Model, labelledloader)
        pseudo_accs, pseudo_losses = _train_one_pass(Model, unlabelledloader, labeller=Labeller)
        train_accs = labelled_accs + pseudo_accs
        train_losses = labelled_losses + pseudo_losses

        train_tot_accs.append(_mean(train_accs))
        train_tot_losses.append(_mean(train_losses))

        print(f"==========Training Accuracy: {train_tot_accs[-1]:.3f} , Training Loss: {train_tot_losses[-1]:.3f}==========")

        valid_accs, valid_losses = _collect_metrics(Model, validloader)
        valid_tot_accs.append(_mean(valid_accs))
        valid_tot_losses.append(_mean(valid_losses))

        print(f"==========Validation Accuracy: {valid_tot_accs[-1]:.3f} , Validation Loss: {valid_tot_losses[-1]:.3f}==========")
  
  



In [None]:
import matplotlib.pyplot as plt

# Per-epoch mean training and validation accuracy of the final model.
# torch.stack expects the per-epoch entries to be tensors.
fig, ax = plt.subplots()
ax.plot(torch.stack(train_tot_accs).cpu().detach().numpy(), label="train")
ax.plot(torch.stack(valid_tot_accs).cpu().detach().numpy(), label="validation")
ax.set(xlabel="Epoch", ylabel="Accuracy", title=f"{config_name}: accuracy per epoch")
ax.legend();


In [None]:

# Per-epoch mean training and validation loss of the final model.
# torch.stack expects the per-epoch entries to be tensors.
fig, ax = plt.subplots()
ax.plot(torch.stack(train_tot_losses).cpu().detach().numpy(), label="train")
ax.plot(torch.stack(valid_tot_losses).cpu().detach().numpy(), label="validation")
ax.set(xlabel="Epoch", ylabel="Loss", title=f"{config_name}: loss per epoch")
ax.legend();

In [None]:
# Evaluate the trained model on the test loader and report the mean of the
# per-batch accuracy and loss.
test_accs, test_losses = [], []

for data, target in testloader:
    data = data.to(device=device)
    # Targets are one-hot encoded over 10 classes, as in the training cells.
    labels = F.one_hot(target, num_classes=10).float().to(device=device)
    batch_acc, batch_loss = Model.evaluation(data, labels)
    test_accs.append(batch_acc)
    test_losses.append(batch_loss)

test_tot_accs = sum(test_accs) / len(test_accs)
test_tot_losses = sum(test_losses) / len(test_losses)

print(f"==========Test Accuracy: {test_tot_accs:.3f} , Test Loss: {test_tot_losses:.3f}==========")

In [None]:
# Write every metric series to ./src/Evaluation/Logs/<config_name>_<metric>.txt,
# one value per line. The two test metrics are single values, so they are
# wrapped in a one-element list to share the same code path.
log_dir = "./src/Evaluation/Logs"
# Create the directory on first use instead of failing with FileNotFoundError.
os.makedirs(log_dir, exist_ok=True)

metric_logs = {
    "train_acc": train_tot_accs,
    "train_loss": train_tot_losses,
    "valid_acc": valid_tot_accs,
    "valid_loss": valid_tot_losses,
    "test_acc": [test_tot_accs],
    "test_loss": [test_tot_losses],
}

for metric_name, values in metric_logs.items():
    with open(f"{log_dir}/{config_name}_{metric_name}.txt", "w") as f:
        f.write("\n".join(str(item.cpu().detach().numpy()) for item in values))

In [None]:
# Persist the trained model object to ./src/Model/<config_name>.pickle.
# The whole Python object is pickled, so loading it back requires the same
# model class to be importable; only load pickle files you trust.
logging.info("Saving model to pickle file")
pickle_path = f"./src/Model/{config_name}.pickle"
with open(pickle_path, "wb") as model_stream:
    pickle.dump(Model, model_stream, protocol=pickle.HIGHEST_PROTOCOL)


