In [None]:
# Mount Google Drive into the Colab virtual machine and switch to the
# project directory so that the project-local modules can be imported.
import os

from google.colab import drive

drive.mount('/content/drive')
# Use the absolute path of the mount point: a relative path only resolves
# while the working directory is still /content, so re-running this cell
# after the directory has already been changed would otherwise fail.
os.chdir('/content/drive/MyDrive/PRJ')

In [None]:
!pip install pytorch-metric-learning --pre
!pip install faiss-gpu
from torchvision import transforms
from pytorch_metric_learning import miners, losses
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import torch
import torch.optim as optim
import CIFAR10
import CIFAR100
import FashionMNIST
import resnet18
import utils
import pickle

# Deep metric learning components: a semi-hard triplet miner, a triplet
# margin loss using the same margin, and the calculator for the reported
# accuracy metrics.
mining_func = miners.TripletMarginMiner(margin=0.2, type_of_triplets="semihard")
loss_func = losses.TripletMarginLoss(margin=0.2)
accuracy_calculator = AccuracyCalculator(
    include=("mean_average_precision_at_r", "AMI", "NMI"),
    k=10,
)

# Neural network related variables
BATCH_SIZE = 256
# Number of training epochs, keyed by dataset name
EPOCHS = {
    "CIFAR10": 60,
    "CIFAR100": 300,
    "FashionMNIST": 60,
}
# Initial normal sample rate and the amount it is lowered by after each run
NORMAL_SAMPLE_RATE = 1
NORMAL_SAMPLE_RATE_DECREMENT = 0.2
# The datasets used in the experiment, in the order they are processed
datasets = ["CIFAR10", "FashionMNIST", "CIFAR100"]
# The augmentation strategy for the training data (0 or 1)
trans_index = 0

# Run the full experiment: for every dataset, train a fresh model at each
# normal sample rate (starting at NORMAL_SAMPLE_RATE and stepping down by
# NORMAL_SAMPLE_RATE_DECREMENT while the rate stays positive) and store the
# per-epoch training/testing records in ./output/<dataset>_<rate>.pkl.

# The device does not change between runs, so create it once.
device = torch.device("cuda")

for dataset_name in datasets:
    normal_sample_rate = NORMAL_SAMPLE_RATE
    while normal_sample_rate > 0:

        # Start every run from a freshly initialised model and optimizer
        model = resnet18.Net(utils.get_param(dataset_name)).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Clear the gradients and the cached GPU memory
        model.zero_grad()
        optimizer.zero_grad()
        torch.cuda.empty_cache()

        # Get the training and testing datasets
        train_dataset = utils.get_dataset(dataset_name, normal_sample_rate, True, trans_index)
        train_loader = torch.utils.data.DataLoader(train_dataset, BATCH_SIZE, shuffle=True)

        test_dataset = utils.get_dataset(dataset_name, normal_sample_rate, False, None)
        test_loader = torch.utils.data.DataLoader(test_dataset, BATCH_SIZE)

        # The information recorded from one complete run: one entry per
        # epoch, holding the training record followed by the test accuracies
        sample_data = []
        epochs = EPOCHS[dataset_name]
        for epoch in range(1, epochs + 1):
            epoch_data = utils.train(model, loss_func, mining_func, train_loader, optimizer, epoch, device)
            accuracies = utils.test(train_dataset, test_dataset, model, accuracy_calculator)
            epoch_data.extend(accuracies)
            sample_data.append(epoch_data)

        # Save the data to a local file; the context manager closes the
        # file even if pickling fails part-way through
        data_file_name = f"./output/{dataset_name}_{round(normal_sample_rate, 1)}.pkl"
        with open(data_file_name, "wb") as data_file:
            pickle.dump(sample_data, data_file)

        # Save the checkpoint to a local file (uncomment to use)

        # parameter_file_name = f"./output/{dataset_name}_checkpoint_{round(normal_sample_rate, 1)}.pth"
        # torch.save(model.state_dict(), parameter_file_name)

        # Decrement the normal sample rate. Repeated float subtraction
        # accumulates error (1 - 5 * 0.2 ends up around 5.6e-17 instead of
        # 0), which would trigger one extra run with an effectively zero
        # sample rate; rounding removes that drift so the loop stops at 0.
        normal_sample_rate = round(normal_sample_rate - NORMAL_SAMPLE_RATE_DECREMENT, 10)