In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image
from tqdm import tqdm

#import mc_classification_learning_extension as mc

In [2]:
np.random.seed(0)

In [3]:
class PseudoLabeledDataset(torch.utils.data.Dataset):
    def __init__(self, labeled_ds, unlabeled_ds, pseudo_labels):
        self.labeled = labeled_ds
        self.unlabeled = unlabeled_ds
        self.pseudo_labels = pseudo_labels
    
    def __len__(self):
        return len(self.labeled) + len(self.unlabeled)
    
    def __getitem__(self, index):
        if index < len(self.labeled):
            labeled_data = self.labeled[index]
            return labeled_data[0], labeled_data[1]  
        else:
            pseudo_index = index - len(self.labeled)
            unlabeled_data = self.unlabeled[pseudo_index][0]
            pseudo_label = int(self.pseudo_labels[pseudo_index])
            return unlabeled_data, pseudo_label
    

In [4]:
class FineTunedResNet():
    def __init__(self, num_classes, dropout_rate=0, criterion=nn.CrossEntropyLoss(), optimizer=optim.Adam, lr=0.0005, pretrained_lr = 0.000005, betas=(0.9, 0.999), eps=1e-8, weight_decay=0) -> None:
        self.is_finetuned = False
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.n_classes = num_classes
        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # initially, all layers are frozen
        for param in self.model.parameters():
            param.requires_grad = False
        # only the final layer is unfrozen
        for param in self.model.fc.parameters():
            param.requires_grad = True

        num_ftrs = self.model.fc.in_features

        # Replace the final layer with a Sequential containing the Dropout and Linear layers
        self.model.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_ftrs, self.n_classes)
        )

        self.model = self.model.to(self.device)

        self.criterion = criterion

        # define an optimizer for each layer
        self.optimizers = []
        for i in range(4):
            params = [{'params': getattr(self.model, f'layer{i+1}').parameters(), 'lr': pretrained_lr}]
            self.optimizers.append(optimizer(params, betas=betas, eps=eps, weight_decay=weight_decay))
        self.optimizers.append(optimizer([{'params': self.model.fc.parameters(), 'lr': lr}], betas=betas, eps=eps, weight_decay=weight_decay))
        
        self.schedulers = []
        # define a scheduler for each optimizer
        for opt in self.optimizers:
            self.schedulers.append(torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.8))

    def alpha_weight(self, step, num_epochs, af=2):
        if step < num_epochs/2:
            return 0.0
        #elif step > num_epochs * 3/4:
        #    return af
        else:
            return ((step-num_epochs/2) / (num_epochs*3/4 - num_epochs/2))# * af

    def train(self, train_loader, n_batches, save_model=True, safepath="", pseudolabeling={"on":False, "unlabeled":[], "confidence":0.9}):
        if pseudolabeling["on"] and not pseudolabeling["unlabeled"]:
                raise Exception("Unlabeled data for pseudo labeling is missing.")
                
        self.is_finetuned = True
        
        #self.model.train()  # Set the model to training mode
        
        #if save_model and not safepath:
        #    safepath = f"ft_model_{self.n_classes}_{num_epochs}.pt"
        
        epoch = 0
        unlabeled_loader = pseudolabeling["unlabeled"]
        num_epochs = len(unlabeled_loader)/n_batches
        for j, (x_un, _) in enumerate(unlabeled_loader):
            x_un = x_un.to(self.device)
            self.model.eval()
            pred = self.model(x_un)
            _, pseudo_label = torch.max(pred, 1)
            self.model.train()
            
            out = self.model(x_un)
            u_loss = self.alpha_weight(epoch, num_epochs) * self.criterion(out, pseudo_label)
            
            u_loss.backward()
            for optimizer in self.optimizers:
                    optimizer.step()
            
            #for epoch in range(num_epochs):
        
            if j % n_batches == 0:

                # decide which layers to unfreeze in this epoch
                if epoch < len(self.optimizers):
                    for param in self.optimizers[epoch].param_groups[0]['params']:
                        param.requires_grad = True

                r_loss = 0.0
                data_loader = train_loader

                for i, data in tqdm(enumerate(data_loader, 0)):
                    x, y = data
                    x = x.to(self.device)
                    y = y.to(self.device)

                    # clear and update all optimizers
                    for optimizer in self.optimizers:
                        optimizer.zero_grad()

                    out = self.model(x)
                    loss = self.criterion(out, y)

                    loss.backward()
                    for optimizer in self.optimizers:
                        optimizer.step()

                    r_loss += loss.item()

                # update all schedulers
                for scheduler in self.schedulers:
                    scheduler.step()

                print(f"Epoch {epoch}: loss {r_loss}")
                epoch += 1
                if epoch >= num_epochs:
                    break

        #if save_model:
        #    torch.save(self.model.state_dict(), safepath)


    def validate(self, val_loader):
        if not self.is_finetuned:
            raise Exception("The model has not been fine-tuned. Call model.train first to fine-tune the model.")
            
        self.model.eval()  # Set the model to evaluation mode
        
        correct = [0] * self.n_classes
        total = [0] * self.n_classes
        with torch.no_grad():
            for i, data in enumerate(val_loader, 0):
                x, y = data
                x = x.to(self.device)
                y = y.to(self.device)
                c = []
                output = self.model(x)
                for out in output:
                    c.append(out.argmax())
                for i in range(len(y)):
                    label = y[i]
                    if label == c[i]:
                        correct[label] += 1
                    total[label] += 1
        return self.get_accuracy(correct, total)

    def test_single(self, test_img):
        if not self.is_finetuned:
            raise Exception("The model has not been fine-tuned. Call model.train first to fine-tune the model.")
        self.model.eval()
        img_tensor = test_img.unsqueeze(0)
        img_tensor = img_tensor.to(self.device)

        out = self.model(img_tensor)
        pred_class = out.argmax().item()

        return pred_class
    
    def get_accuracy(self, correct, total):
        acc_class = np.zeros((self.n_classes, 1))
        for i in range(self.n_classes):
            acc_class[i] = correct[i]/total[i]
        acc = sum(correct)/sum(total)
        return acc_class, acc
    
    def generate_pseudo_labels(self, unlabeled_loader, confidence):
            pseudo_labels = []
            high_conf_indices = []
            with torch.no_grad():
                for i, (x, _) in enumerate(unlabeled_loader):
                    x = x.to(self.device)
                    c = []
                    output = self.model(x)
                    for j in range(len(output)):
                        pred = output[j].argmax()
                        max_conf = torch.nn.functional.softmax(output[j], dim=0).max()
                        if max_conf >= confidence:
                            c.append(pred)
                            high_conf_indices.append(i * unlabeled_loader.batch_size + j)

                    pseudo_labels.extend(c)
            return pseudo_labels, high_conf_indices

In [5]:
traindir = "data/mc_data/mc_training"
validdir = "data/mc_data/mc_validation"
testdir = "data/mc_data/mc_test"

transform  = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
])

train_set = datasets.ImageFolder(traindir, transform)
val_set = datasets.ImageFolder(validdir, transform)
test_set = datasets.ImageFolder(testdir, transform)

In [13]:
labeled_split = 0.1
num_data = len(train_set)
num_labeled = int(num_data * labeled_split)

# Get the set of all unique classes
classes = set(train_set.targets)
labeled_indices = []
class_labeled_indices = []
for class_idx in classes:
    class_indices = [i for i, target in enumerate(train_set.targets) if target == class_idx]
    subset_size = max(1, int(len(class_indices) * labeled_split))
    class_subset_indices = np.random.choice(class_indices, subset_size, replace=False)

    labeled_indices.extend(class_subset_indices)
    class_labeled_indices.append(class_subset_indices)

#labeled_indices = np.random.choice(num_data, num_labeled, replace=False)
unlabeled_indices = np.delete(np.arange(0, num_data, 1), labeled_indices)

In [14]:
labeled_subset = torch.utils.data.Subset(train_set, labeled_indices)
unlabeled_subset = torch.utils.data.Subset(train_set, unlabeled_indices)

In [15]:
labeled_loader = torch.utils.data.DataLoader(
    labeled_subset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)

unlabeled_loader = torch.utils.data.DataLoader(
    unlabeled_subset,
    batch_size=4,
    shuffle=False,
    num_workers=2
)

classes = train_set.classes

In [9]:
val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=4,
        shuffle=True,
        num_workers=2
    )
test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=4,
        shuffle=True,
        num_workers=2
    )

In [16]:
ps_model = FineTunedResNet(37)
ps_settings = {"on": True, 
               "unlabeled": unlabeled_loader}
# middle number != num_epochs, 
# determines after how many batches of unlabeled data one epoch of training is done on the labeled set
ps_model.train(labeled_loader, 5, pseudolabeling=ps_settings)

73it [00:02, 25.42it/s]


Epoch 0: loss 259.48752093315125


73it [00:03, 23.86it/s]


Epoch 1: loss 198.3794651031494


73it [00:03, 20.89it/s]


Epoch 2: loss 162.03794312477112


73it [00:04, 17.88it/s]


Epoch 3: loss 125.6178650856018


73it [00:04, 17.89it/s]


Epoch 4: loss 93.71319144964218


73it [00:04, 17.86it/s]


Epoch 5: loss 73.12363365292549


73it [00:04, 17.75it/s]


Epoch 6: loss 59.71476849913597


73it [00:04, 17.65it/s]


Epoch 7: loss 45.780975088477135


73it [00:04, 17.58it/s]


Epoch 8: loss 43.347234696149826


73it [00:04, 17.47it/s]


Epoch 9: loss 33.46270537376404


73it [00:05, 14.07it/s]


Epoch 10: loss 28.8371010273695


73it [00:04, 17.69it/s]


Epoch 11: loss 26.96628039330244


73it [00:04, 17.85it/s]


Epoch 12: loss 25.15172991901636


73it [00:04, 17.79it/s]


Epoch 13: loss 22.247276082634926


73it [00:04, 17.82it/s]


Epoch 14: loss 22.183743327856064


73it [00:04, 17.79it/s]


Epoch 15: loss 18.51471944153309


73it [00:04, 17.90it/s]


Epoch 16: loss 16.471517138183117


73it [00:04, 17.99it/s]


Epoch 17: loss 14.568467415869236


73it [00:04, 18.04it/s]


Epoch 18: loss 13.813299026340246


73it [00:04, 17.98it/s]


Epoch 19: loss 13.729018870741129


73it [00:04, 17.93it/s]


Epoch 20: loss 12.505056705325842


73it [00:04, 17.93it/s]


Epoch 21: loss 15.445683157071471


73it [00:04, 17.74it/s]


Epoch 22: loss 14.332721062004566


73it [00:04, 15.31it/s]


Epoch 23: loss 13.002751793712378


73it [00:04, 16.92it/s]


Epoch 24: loss 11.588381756097078


73it [00:04, 17.82it/s]


Epoch 25: loss 11.064884703606367


73it [00:04, 17.82it/s]


Epoch 26: loss 12.136933200061321


73it [00:04, 17.88it/s]


Epoch 27: loss 10.346416711807251


73it [00:04, 17.91it/s]


Epoch 28: loss 12.921646237373352


73it [00:04, 17.71it/s]


Epoch 29: loss 9.957668855786324


73it [00:04, 17.83it/s]


Epoch 30: loss 9.594204626977444


73it [00:04, 17.92it/s]


Epoch 31: loss 9.388128474354744


73it [00:04, 17.95it/s]


Epoch 32: loss 10.161199793219566


73it [00:04, 17.87it/s]


Epoch 33: loss 9.462566936388612


73it [00:04, 17.93it/s]


Epoch 34: loss 8.508437674492598


73it [00:04, 17.94it/s]


Epoch 35: loss 9.939454862847924


73it [00:04, 16.17it/s]


Epoch 36: loss 8.060041582211852


73it [00:04, 15.81it/s]


Epoch 37: loss 10.733338985592127


73it [00:04, 17.88it/s]


Epoch 38: loss 8.751976758241653


73it [00:04, 17.90it/s]


Epoch 39: loss 10.949629062786698


73it [00:04, 17.86it/s]


Epoch 40: loss 9.498484645038843


73it [00:04, 17.86it/s]


Epoch 41: loss 8.909307695925236


73it [00:04, 17.79it/s]


Epoch 42: loss 8.976699465885758


73it [00:04, 17.82it/s]


Epoch 43: loss 9.927986508235335


73it [00:04, 17.91it/s]


Epoch 44: loss 8.944966897368431


73it [00:04, 17.85it/s]


Epoch 45: loss 7.144225252792239


73it [00:04, 17.88it/s]


Epoch 46: loss 9.054224649444222


73it [00:04, 17.88it/s]


Epoch 47: loss 6.947392979636788


73it [00:04, 17.87it/s]


Epoch 48: loss 8.794139862060547


73it [00:04, 17.78it/s]


Epoch 49: loss 7.7606877237558365


73it [00:05, 14.36it/s]


Epoch 50: loss 9.271527277305722


73it [00:04, 17.84it/s]


Epoch 51: loss 9.239990197122097


73it [00:04, 17.95it/s]


Epoch 52: loss 7.987172156572342


73it [00:04, 17.91it/s]


Epoch 53: loss 6.711514504626393


73it [00:04, 17.91it/s]


Epoch 54: loss 8.96571483835578


73it [00:04, 17.89it/s]


Epoch 55: loss 7.415797828696668


73it [00:04, 17.83it/s]


Epoch 56: loss 7.768172807991505


73it [00:04, 17.82it/s]


Epoch 57: loss 8.05074910260737


73it [00:04, 17.95it/s]


Epoch 58: loss 6.8721074890345335


73it [00:04, 17.87it/s]


Epoch 59: loss 8.150726228952408


73it [00:04, 17.92it/s]


Epoch 60: loss 6.600090444087982


73it [00:04, 17.89it/s]


Epoch 61: loss 7.558829626999795


73it [00:04, 17.76it/s]


Epoch 62: loss 8.831721441820264


73it [00:05, 14.08it/s]


Epoch 63: loss 7.474516283720732


73it [00:04, 17.81it/s]


Epoch 64: loss 7.763640977442265


73it [00:04, 17.82it/s]


Epoch 65: loss 8.957270991057158


73it [00:04, 17.90it/s]


Epoch 66: loss 5.216547150164843


73it [00:04, 17.86it/s]


Epoch 67: loss 8.43885584268719


73it [00:04, 17.91it/s]


Epoch 68: loss 8.365812207572162


73it [00:04, 17.74it/s]


Epoch 69: loss 7.838579750619829


73it [00:04, 17.88it/s]


Epoch 70: loss 8.059890802018344


73it [00:04, 17.88it/s]


Epoch 71: loss 7.520674891769886


73it [00:04, 17.81it/s]


Epoch 72: loss 6.123321617953479


73it [00:04, 17.78it/s]


Epoch 73: loss 6.258097080513835


73it [00:04, 17.89it/s]


Epoch 74: loss 7.142702309414744


73it [00:04, 17.84it/s]


Epoch 75: loss 7.4114334266632795


73it [00:04, 15.50it/s]


Epoch 76: loss 9.617822005413473


73it [00:04, 16.44it/s]


Epoch 77: loss 6.9263916015625


73it [00:04, 17.92it/s]


Epoch 78: loss 7.089059522375464


73it [00:04, 17.74it/s]


Epoch 79: loss 7.1777111906558275


73it [00:04, 17.87it/s]


Epoch 80: loss 8.481403497979045


73it [00:04, 17.82it/s]


Epoch 81: loss 7.668276458978653


73it [00:04, 17.76it/s]


Epoch 82: loss 8.43086860049516


73it [00:04, 17.79it/s]


Epoch 83: loss 6.385275202803314


73it [00:04, 17.86it/s]


Epoch 84: loss 10.312651264481246


73it [00:04, 17.84it/s]


Epoch 85: loss 7.199020277708769


73it [00:04, 17.91it/s]


Epoch 86: loss 7.92336561717093


73it [00:04, 17.88it/s]


Epoch 87: loss 10.652787069790065


73it [00:04, 17.92it/s]


Epoch 88: loss 8.034868601709604


73it [00:04, 16.54it/s]


Epoch 89: loss 8.139337953180075


73it [00:04, 15.78it/s]


Epoch 90: loss 7.827641731128097


73it [00:04, 17.87it/s]


Epoch 91: loss 7.947548542171717


73it [00:04, 17.81it/s]


Epoch 92: loss 6.5782490419223905


73it [00:04, 17.85it/s]


Epoch 93: loss 7.341130951419473


73it [00:04, 17.83it/s]


Epoch 94: loss 8.344713728874922


73it [00:04, 17.83it/s]


Epoch 95: loss 8.768610265105963


73it [00:04, 17.76it/s]


Epoch 96: loss 7.83632255345583


73it [00:04, 17.83it/s]


Epoch 97: loss 9.270234359428287


73it [00:04, 17.77it/s]


Epoch 98: loss 7.858106744475663


73it [00:04, 17.92it/s]


Epoch 99: loss 7.287183251231909


73it [00:04, 17.89it/s]


Epoch 100: loss 6.515707416459918


73it [00:04, 17.87it/s]


Epoch 101: loss 8.38829866424203


73it [00:04, 17.87it/s]


Epoch 102: loss 6.7500598123297095


73it [00:05, 14.13it/s]


Epoch 103: loss 7.161962736397982


73it [00:04, 17.86it/s]


Epoch 104: loss 6.945156129077077


73it [00:04, 17.86it/s]


Epoch 105: loss 7.556191432289779


73it [00:04, 17.89it/s]


Epoch 106: loss 7.687712199985981


73it [00:04, 17.84it/s]


Epoch 107: loss 9.935396689921618


73it [00:04, 17.92it/s]


Epoch 108: loss 7.173551447689533


73it [00:04, 17.67it/s]


Epoch 109: loss 9.175725182518363


73it [00:04, 17.85it/s]


Epoch 110: loss 8.563073688186705


73it [00:04, 17.86it/s]


Epoch 111: loss 7.2671637032181025


73it [00:04, 17.85it/s]


Epoch 112: loss 7.008296033367515


73it [00:04, 17.82it/s]


Epoch 113: loss 7.747476814314723


73it [00:04, 17.94it/s]


Epoch 114: loss 8.165400149300694


73it [00:04, 17.84it/s]


Epoch 115: loss 7.078112091869116


73it [00:05, 14.17it/s]


Epoch 116: loss 7.411943709477782


73it [00:04, 17.82it/s]


Epoch 117: loss 6.9980242643505335


73it [00:04, 17.84it/s]


Epoch 118: loss 8.023639580234885


73it [00:04, 17.86it/s]


Epoch 119: loss 7.243184294551611


73it [00:04, 17.87it/s]


Epoch 120: loss 8.740171084180474


73it [00:04, 17.85it/s]


Epoch 121: loss 8.59961883816868


73it [00:04, 17.75it/s]


Epoch 122: loss 7.1280162669718266


73it [00:04, 17.90it/s]


Epoch 123: loss 7.2376061556860805


73it [00:04, 17.89it/s]


Epoch 124: loss 8.641208034008741


73it [00:04, 17.88it/s]


Epoch 125: loss 7.230211474001408


73it [00:04, 17.96it/s]


Epoch 126: loss 8.391946841962636


73it [00:04, 17.90it/s]


Epoch 127: loss 7.523797178640962


73it [00:04, 17.83it/s]


Epoch 128: loss 8.563752168789506


73it [00:04, 15.46it/s]


Epoch 129: loss 7.81800137180835


73it [00:04, 17.03it/s]


Epoch 130: loss 8.005762005224824


73it [00:04, 17.88it/s]


Epoch 131: loss 7.1398048512637615


73it [00:04, 17.86it/s]


Epoch 132: loss 7.655453495681286


In [17]:
ps_model.validate(val_loader)

(array([[0.75      ],
        [0.55      ],
        [0.7       ],
        [0.89473684],
        [0.8       ],
        [0.77777778],
        [0.8       ],
        [0.6       ],
        [0.8       ],
        [0.65      ],
        [0.78947368],
        [1.        ],
        [0.8       ],
        [0.8       ],
        [0.9       ],
        [0.75      ],
        [0.7       ],
        [0.55      ],
        [0.89473684],
        [1.        ],
        [0.85      ],
        [1.        ],
        [0.95      ],
        [1.        ],
        [1.        ],
        [0.9       ],
        [0.9       ],
        [0.89473684],
        [0.9       ],
        [0.9       ],
        [0.95      ],
        [0.85      ],
        [0.75      ],
        [1.        ],
        [0.6       ],
        [0.95      ],
        [0.95      ]]),
 0.8337874659400545)

In [18]:
ps_model.validate(test_loader)

(array([[0.6122449 ],
        [0.72      ],
        [0.7       ],
        [0.72727273],
        [0.61      ],
        [0.71134021],
        [0.77      ],
        [0.57      ],
        [0.57      ],
        [0.61      ],
        [0.66      ],
        [0.74      ],
        [0.73      ],
        [0.56      ],
        [0.97      ],
        [0.79      ],
        [0.94949495],
        [0.68      ],
        [0.88      ],
        [0.88      ],
        [0.99      ],
        [0.95      ],
        [0.96      ],
        [1.        ],
        [0.98989899],
        [0.94      ],
        [0.91      ],
        [0.86      ],
        [0.79      ],
        [0.96      ],
        [0.97      ],
        [0.98      ],
        [0.92929293],
        [0.95      ],
        [0.61797753],
        [0.94      ],
        [0.92      ]]),
 0.8143908421913328)