In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import numpy as np
import os
import time
from torch.utils.data import random_split
from torch.optim import lr_scheduler
from tqdm import tqdm

In [None]:
# Dataset that yields (image, target) pairs; image paths and labels are read from a CSV file
class MYDATA_labeled(Dataset):
    """Image dataset whose samples and labels are listed in a CSV file.

    The CSV is loaded with ``pandas.read_csv``. Column 0 holds image paths
    that are joined onto ``root``; column 1 holds the labels, which are stored
    minus one (presumably the CSV labels are 1-based -- confirm against the
    label files).

    Args:
        root: Directory that the image paths in the CSV are relative to.
        label_csv: Path of the CSV file listing images and labels.
        transform: Optional callable applied to every loaded RGB image.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root, label_csv, transform=None):
        self.transform = transform
        # No constructor argument sets this; assign it on the instance if needed.
        self.target_transform = None
        df = pd.read_csv(label_csv)
        # Read whole columns at once instead of indexing the frame cell by cell.
        self.img_path_list = [os.path.join(root, name) for name in df.iloc[:, 0]]
        # Shift labels down by one so they can be used as class indices.
        self.target_list = (df.iloc[:, 1] - 1).tolist()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # Image.open keeps the underlying file open; convert() returns an
        # independent in-memory copy, so the source handle can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(self.img_path_list[index]) as raw_img:
            img = raw_img.convert("RGB")
        target = self.target_list[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.img_path_list)

In [None]:
class EarlyStopping:
    """Checkpoint the best models seen so far and flag when training should stop.

    Call the instance once per epoch with the validation loss, the validation
    accuracy and the model. Two checkpoints are maintained:

    * ``res_best_loss_checkpoint.pt`` -- written whenever the loss improves.
    * ``res_best_acc_checkpoint.pt``  -- written whenever the accuracy exceeds
      the best accuracy seen so far.

    ``early_stop`` becomes True once the loss has failed to improve for
    ``patience`` consecutive calls. The caller is responsible for checking
    that flag and leaving the training loop.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: If True, print a message every time a checkpoint is written.
        delta: Margin added to the best score; a call counts as an
            improvement only when ``-val_loss >= best_score + delta``.
        save_dir: Directory the checkpoints are written to. When None, the
            module-level ``savedir`` variable is used (the original behaviour).
    """

    def __init__(self, patience=7, verbose=False, delta=0, save_dir=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf is available in every version.
        self.val_loss_min = np.inf
        self.val_acc_max = 0.0
        self.delta = delta
        self.save_dir = save_dir

    def __call__(self, val_loss, val_acc, model):
        """Update the stopping state with this epoch's validation metrics."""
        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we have is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model, "best_loss")
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, "best_loss")
            self.counter = 0

        # Accuracy is tracked independently of the loss-based patience counter.
        if self.val_acc_max < val_acc:
            self.val_acc_max = val_acc
            self.save_checkpoint(val_loss, model, "best_acc")

    def save_checkpoint(self, val_loss, model, name):
        """Write ``model.state_dict()`` to ``res_<name>_checkpoint.pt``.

        ``val_loss_min`` is only updated for loss-driven checkpoints, so an
        accuracy-driven save no longer overwrites the best loss on record.
        """
        is_acc_checkpoint = name == "best_acc"
        if self.verbose:
            if is_acc_checkpoint:
                print(f'Validation accuracy improved to {float(self.val_acc_max):.6f}. Saving model ...')
            else:
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        # Fall back to the notebook-level `savedir` when no directory was given.
        target_dir = savedir if self.save_dir is None else self.save_dir
        torch.save(model.state_dict(), os.path.join(target_dir, 'res_{}_checkpoint.pt'.format(name)))
        if not is_acc_checkpoint:
            self.val_loss_min = val_loss

In [None]:
def _normalized_tensor_steps():
    """Return the tensor-conversion and normalisation steps shared by both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4355, 0.4345, 0.4332], std=[0.2830, 0.2826, 0.2818]),
    ]


# Preprocessing pipelines keyed by split name. Both resize to 224x224 and end
# with the same tensor conversion + normalisation; only "train" adds random
# augmentation (horizontal flip and a rotation of up to 10 degrees).
# RandomCrop, RandomVerticalFlip and ColorJitter are intentionally not applied.
transform = dict(
    train=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        *_normalized_tensor_steps(),
    ]),
    val=transforms.Compose([
        transforms.Resize((224, 224)),
        *_normalized_tensor_steps(),
    ]),
)

In [None]:
# Load labeled dataset
# All data locations are resolved relative to root_dir.
root_dir = "."

# Alternative labeled subsets sit next to this one: "1k_label" and "1k6_label"
# (with "1k_true_label_csv.csv" / "1k6_true_label_csv.csv" as their label files).
labeled_dir, unlabeled_dir, val_dir = (
    os.path.join(root_dir, sub_dir)
    for sub_dir in ("3k_label", "29k_unlabel", "testset_4k5")
)
# unlabeled_dir is defined here but not used by any of the cells visible in this notebook.

train_label_csv, test_label_csv = (
    os.path.join(root_dir, csv_name)
    for csv_name in ("all_3k_train_true_label_csv.csv", "testset_4k5_true_label_csv.csv")
)

# NOTE(review): the validation split is read from the "testset_4k5" files --
# confirm this set is not also used for the final reported test metrics.
train_dataset = MYDATA_labeled(root=labeled_dir, label_csv=train_label_csv, transform=transform["train"])
val_dataset = MYDATA_labeled(root=val_dir, label_csv=test_label_csv, transform=transform["val"])

# Create DataLoaders for the training and validation datasets.
# Only the training loader shuffles; everything else is shared.
_loader_options = dict(batch_size=16, num_workers=4)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_options)

In [None]:
# save model
# Every execution of this cell gets its own "run<N>" directory under
# <root_dir>/checkpoint, stored in `savedir` for the checkpointing code.
checkpoint_root = os.path.join(root_dir, "checkpoint")
os.makedirs(checkpoint_root, exist_ok=True)

# Start from the entry count + 1, then skip forward past any directory that
# already exists. Counting entries alone can land on an existing run (e.g.
# after an older run was deleted, or when stray files live in the folder),
# which would silently overwrite that run's checkpoints.
exp_id = len(os.listdir(checkpoint_root)) + 1
while os.path.exists(os.path.join(checkpoint_root, "run{}".format(exp_id))):
    exp_id += 1

savedir = os.path.join(checkpoint_root, "run{}".format(exp_id))
os.makedirs(savedir)

In [None]:
# load pretrained model
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ResNet-50 with pretrained weights; its final fully connected layer is
# replaced by a fresh 8-output linear layer with the same input width.
model = models.resnet50(pretrained=True)
feature = model.fc.in_features
model.fc = nn.Linear(feature, 8, bias=True)

model = model.to(device)

In [None]:
# hyperparams
num_epochs = 100  # set it to a large number

# Cross-entropy loss, optimised with Adam over all model parameters.
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

# Halve the learning rate every 50 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=50, gamma=0.5)

# initialize the early_stopping object
early_stopping = EarlyStopping(verbose=True, patience=10)

In [None]:
# Train for up to num_epochs epochs: one pass over the training loader, one
# pass over the validation loader, then checkpointing / early-stopping.
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_train_loss = 0.0
    running_train_corrects = 0
    print("trainning...")
    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = loss_f(outputs, labels)

        loss.backward()
        optimizer.step()

        # loss is a batch mean, so weight it by the batch size before summing.
        running_train_loss += loss.item() * inputs.size(0)
        # Accumulate plain Python ints so the totals do not depend on tensor
        # methods (and stay valid even if a loader yields no batches).
        running_train_corrects += (preds == labels).sum().item()

    # Validation phase
    model.eval()
    running_val_loss = 0.0
    running_val_corrects = 0
    print("validation...")
    with torch.no_grad():
        for inputs, labels in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = loss_f(outputs, labels)

            running_val_loss += loss.item() * inputs.size(0)
            running_val_corrects += (preds == labels).sum().item()

    epoch_train_loss = running_train_loss / len(train_dataset)
    epoch_train_acc = running_train_corrects / len(train_dataset)
    epoch_val_loss = running_val_loss / len(val_dataset)
    epoch_val_acc = running_val_corrects / len(val_dataset)

    print(f'Epoch {epoch+1}/{num_epochs} Train Loss: {epoch_train_loss:.4f} Train Acc: {epoch_train_acc:.4f} Val Loss: {epoch_val_loss:.4f} Val Acc: {epoch_val_acc:.4f}')

    # Step the LR schedule once per epoch, after this epoch's optimizer
    # updates. Stepping it before any optimizer.step() shifts the schedule
    # by one epoch and triggers a PyTorch warning.
    exp_lr_scheduler.step()

    # early_stopping needs the validation loss to check if it has decreased,
    # and if it has, it will make a checkpoint of the current model
    early_stopping(epoch_val_loss, epoch_val_acc, model)

    # The stopper only raises a flag; leaving the loop is up to us.
    if early_stopping.early_stop:
        print("Early stopping")
        break