In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image
from tqdm import tqdm

import mc_classification_learning_extension as mc

In [2]:
np.random.seed(0)

In [3]:
traindir = "data/mc_data/mc_training"
validdir = "data/mc_data/mc_validation"
testdir = "data/mc_data/mc_test"

transform  = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
])

train_set = datasets.ImageFolder(traindir, transform)
val_set = datasets.ImageFolder(validdir, transform)
test_set = datasets.ImageFolder(testdir, transform)

In [4]:
labeled_split = 0.1
num_data = len(train_set)
num_labeled = int(num_data * labeled_split)

# Get the set of all unique classes
classes = set(train_set.targets)
labeled_indices = []
class_labeled_indices = []
for class_idx in classes:
    class_indices = [i for i, target in enumerate(train_set.targets) if target == class_idx]
    subset_size = max(1, int(len(class_indices) * labeled_split))
    class_subset_indices = np.random.choice(class_indices, subset_size, replace=False)

    labeled_indices.extend(class_subset_indices)
    class_labeled_indices.append(class_subset_indices)

#labeled_indices = np.random.choice(num_data, num_labeled, replace=False)
unlabeled_indices = np.delete(np.arange(0, num_data, 1), labeled_indices)

In [5]:
labeled_subset = torch.utils.data.Subset(train_set, labeled_indices)
unlabeled_subset = torch.utils.data.Subset(train_set, unlabeled_indices)

In [6]:
for c in range(len(class_labeled_indices)):
    print(c)
    for i in class_labeled_indices[c]:
        print(train_set.imgs[i])
    print()

0
('data/mc_data/mc_training/Abyssinian/Abyssinian_145.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_124.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_127.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_174.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_140.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_167.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_160.jpg', 0)
('data/mc_data/mc_training/Abyssinian/Abyssinian_164.jpg', 0)

1
('data/mc_data/mc_training/Bengal/Bengal_103.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_14.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_156.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_121.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_104.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_133.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_110.jpg', 1)
('data/mc_data/mc_training/Bengal/Bengal_106.jpg', 1)

2
('data/mc_data/mc_training/Birman/Birman_170.jpg', 2)
('data/mc_d

In [7]:
dict = {}
for i, (_, targets) in enumerate(unlabeled_subset):
    if targets in dict:
        dict[targets] = dict[targets] + 1
    else:
        dict[targets] = 0
print(dict)

{0: 71, 1: 71, 2: 71, 3: 69, 4: 71, 5: 67, 6: 71, 7: 71, 8: 71, 9: 71, 10: 71, 11: 71, 12: 71, 13: 71, 14: 71, 15: 71, 16: 71, 17: 71, 18: 69, 19: 71, 20: 71, 21: 71, 22: 71, 23: 71, 24: 71, 25: 71, 26: 71, 27: 69, 28: 71, 29: 71, 30: 71, 31: 71, 32: 71, 33: 71, 34: 71, 35: 71, 36: 71}


In [8]:
labeled_loader = torch.utils.data.DataLoader(
    labeled_subset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)

unlabeled_loader = torch.utils.data.DataLoader(
    unlabeled_subset,
    batch_size=4,
    shuffle=False,
    num_workers=2
)

classes = train_set.classes

In [9]:
val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=4,
        shuffle=True,
        num_workers=2
    )
test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=4,
        shuffle=True,
        num_workers=2
    )

In [10]:
model = mc.FineTunedResNet(37)
model.train(labeled_loader, 20)

73it [00:09,  7.34it/s]

Epoch 0: loss 254.95661568641663



73it [00:06, 11.11it/s]

Epoch 1: loss 200.7350857257843



73it [00:07, 10.27it/s]

Epoch 2: loss 160.76761496067047



73it [00:08,  8.51it/s]

Epoch 3: loss 129.74821364879608



73it [00:08,  8.53it/s]

Epoch 4: loss 101.3294734954834



73it [00:08,  8.81it/s]

Epoch 5: loss 75.83285987377167



73it [00:08,  8.71it/s]

Epoch 6: loss 60.526029735803604



73it [00:08,  8.94it/s]

Epoch 7: loss 53.94249480962753



73it [00:08,  9.04it/s]

Epoch 8: loss 46.16460299491882



73it [00:08,  8.93it/s]

Epoch 9: loss 38.46992293000221



73it [00:08,  8.49it/s]

Epoch 10: loss 33.395109444856644



73it [00:08,  8.14it/s]

Epoch 11: loss 28.44292240589857



73it [00:08,  8.71it/s]

Epoch 12: loss 25.48488650470972



73it [00:08,  8.99it/s]

Epoch 13: loss 27.545086286962032



73it [00:08,  8.66it/s]

Epoch 14: loss 25.008963234722614



73it [00:08,  8.70it/s]

Epoch 15: loss 23.90783540159464



73it [00:08,  8.65it/s]

Epoch 16: loss 22.134574007242918



73it [00:08,  8.75it/s]

Epoch 17: loss 19.633832290768623



73it [00:09,  7.93it/s]

Epoch 18: loss 18.41041848063469



73it [00:08,  8.64it/s]


Epoch 19: loss 18.829732317477465


In [11]:
model.validate(val_loader)

(array([[0.65      ],
        [0.45      ],
        [0.85      ],
        [0.84210526],
        [0.75      ],
        [0.66666667],
        [0.75      ],
        [0.8       ],
        [0.9       ],
        [0.6       ],
        [1.        ],
        [1.        ],
        [0.95      ],
        [0.75      ],
        [0.9       ],
        [0.8       ],
        [0.7       ],
        [0.65      ],
        [0.78947368],
        [0.85      ],
        [0.9       ],
        [0.9       ],
        [0.65      ],
        [1.        ],
        [1.        ],
        [0.55      ],
        [0.85      ],
        [0.94736842],
        [0.95      ],
        [0.95      ],
        [0.95      ],
        [1.        ],
        [0.7       ],
        [0.95      ],
        [0.8       ],
        [1.        ],
        [1.        ]]),
 0.8310626702997275)

In [12]:
model.validate(test_loader)

(array([[0.62244898],
        [0.69      ],
        [0.73      ],
        [0.57954545],
        [0.55      ],
        [0.71134021],
        [0.68      ],
        [0.71      ],
        [0.64      ],
        [0.6       ],
        [0.8       ],
        [0.85      ],
        [0.87      ],
        [0.47      ],
        [0.9       ],
        [0.98      ],
        [0.8989899 ],
        [0.61      ],
        [0.88      ],
        [0.88      ],
        [0.99      ],
        [0.88      ],
        [0.81      ],
        [0.98      ],
        [1.        ],
        [0.8       ],
        [0.83      ],
        [0.98      ],
        [0.76      ],
        [0.99      ],
        [0.97      ],
        [1.        ],
        [0.80808081],
        [0.98      ],
        [0.58426966],
        [0.97      ],
        [1.        ]]),
 0.8119378577269011)

In [13]:
model.validate(unlabeled_loader)

(array([[0.76388889],
        [0.65277778],
        [0.80555556],
        [0.77142857],
        [0.79166667],
        [0.79411765],
        [0.875     ],
        [0.84722222],
        [0.69444444],
        [0.58333333],
        [0.91666667],
        [0.98611111],
        [0.72222222],
        [0.76388889],
        [0.95833333],
        [0.88888889],
        [0.88888889],
        [0.68055556],
        [0.78571429],
        [0.84722222],
        [0.90277778],
        [0.91666667],
        [0.76388889],
        [0.91666667],
        [0.97222222],
        [0.83333333],
        [0.84722222],
        [0.97142857],
        [0.90277778],
        [0.98611111],
        [0.94444444],
        [0.95833333],
        [0.83333333],
        [0.91666667],
        [0.79166667],
        [0.98611111],
        [0.90277778]]),
 0.8477769404672193)

In [14]:
def generate_pseudo_labels(model, unlabeled_loader, confidence=0.9):
        pseudo_labels = []
        high_conf_indices = []
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            for i, (x, _) in enumerate(unlabeled_loader):
                x = x.to(device)
                c = []
                output = model.model(x)
                for j in range(len(output)):
                    pred = output[j].argmax()
                    max_conf = torch.nn.functional.softmax(output[j], dim=0).max()
                    if max_conf >= confidence:
                        c.append(pred)
                        high_conf_indices.append(i * unlabeled_loader.batch_size + j)

                pseudo_labels.extend(c)
        return pseudo_labels, high_conf_indices
    

ps_labels, high_conf_indices = generate_pseudo_labels(model, unlabeled_loader, confidence=0.2)

In [15]:
ps_subset = torch.utils.data.Subset(unlabeled_subset, high_conf_indices)

In [17]:
correct = 0
for i, (x,y) in enumerate(ps_subset):
    if y == int(ps_labels[i]):
        correct += 1
correct/len(ps_subset)

0.8548761609907121

In [18]:
len(ps_subset)

2584

In [19]:
class PseudoLabeledDataset(torch.utils.data.Dataset):
    def __init__(self, labeled_ds, unlabeled_ds, pseudo_labels):
        self.labeled = labeled_ds
        self.unlabeled = unlabeled_ds
        self.pseudo_labels = pseudo_labels
    
    def __len__(self):
        return len(self.labeled) + len(self.unlabeled)
    
    def __getitem__(self, index):
        if index < len(self.labeled):
            labeled_data = self.labeled[index]
            return labeled_data[0], labeled_data[1]  
        else:
            pseudo_index = index - len(self.labeled)
            unlabeled_data = self.unlabeled[pseudo_index][0]
            pseudo_label = int(self.pseudo_labels[pseudo_index])
            return unlabeled_data, pseudo_label
    

In [20]:
ps_trainset = PseudoLabeledDataset(labeled_subset, ps_subset, ps_labels)
ps_loader = torch.utils.data.DataLoader(
        ps_trainset,
        batch_size=4,
        shuffle=True,
        num_workers=0
    )

In [21]:
ps_model = mc.FineTunedResNet(37)
ps_model.train(ps_loader, 20)

719it [00:48, 14.88it/s]


Epoch 0: loss 1588.021828353405


719it [00:57, 12.46it/s]


Epoch 1: loss 788.5089599788189


719it [01:01, 11.65it/s]


Epoch 2: loss 520.8223583698273


719it [00:52, 13.57it/s]


Epoch 3: loss 375.4866352006793


719it [00:48, 14.80it/s]


Epoch 4: loss 302.1360163837671


719it [00:56, 12.62it/s]


Epoch 5: loss 241.5358338141814


719it [01:20,  8.93it/s]


Epoch 6: loss 207.62592354789376


719it [01:25,  8.42it/s]


Epoch 7: loss 180.96079348400235


719it [01:02, 11.52it/s]


Epoch 8: loss 159.13859037472866


719it [00:55, 12.86it/s]


Epoch 9: loss 137.51581201481167


719it [01:09, 10.34it/s]


Epoch 10: loss 119.89766450040042


719it [00:57, 12.59it/s]


Epoch 11: loss 106.57983703853097


719it [00:55, 13.00it/s]


Epoch 12: loss 98.04784004285466


719it [01:18,  9.12it/s]


Epoch 13: loss 89.11583813483594


719it [01:08, 10.51it/s]


Epoch 14: loss 86.21440769633045


719it [01:05, 10.97it/s]


Epoch 15: loss 79.08621838741237


719it [01:13,  9.72it/s]


Epoch 16: loss 74.0737152909278


719it [01:13,  9.75it/s]


Epoch 17: loss 71.66249714780133


719it [01:05, 11.03it/s]


Epoch 18: loss 65.49375093332492


719it [01:24,  8.53it/s]


Epoch 19: loss 59.01214573858306


In [22]:
ps_model.validate(val_loader)

(array([[0.8       ],
        [0.65      ],
        [0.8       ],
        [0.89473684],
        [0.65      ],
        [0.83333333],
        [0.85      ],
        [0.9       ],
        [0.75      ],
        [0.5       ],
        [1.        ],
        [1.        ],
        [0.95      ],
        [0.55      ],
        [0.9       ],
        [0.75      ],
        [0.75      ],
        [0.4       ],
        [0.94736842],
        [1.        ],
        [0.9       ],
        [0.9       ],
        [0.7       ],
        [1.        ],
        [1.        ],
        [0.85      ],
        [0.9       ],
        [0.94736842],
        [0.9       ],
        [0.95      ],
        [0.95      ],
        [0.95      ],
        [0.7       ],
        [0.95      ],
        [0.9       ],
        [1.        ],
        [1.        ]]),
 0.8474114441416893)

In [23]:
ps_model.validate(test_loader)

(array([[0.70408163],
        [0.79      ],
        [0.62      ],
        [0.85227273],
        [0.72      ],
        [0.79381443],
        [0.74      ],
        [0.85      ],
        [0.56      ],
        [0.64      ],
        [0.91      ],
        [0.91      ],
        [0.88      ],
        [0.42      ],
        [0.95      ],
        [0.93      ],
        [0.93939394],
        [0.82      ],
        [0.93      ],
        [0.98      ],
        [1.        ],
        [0.96      ],
        [0.87      ],
        [0.98      ],
        [1.        ],
        [0.95      ],
        [0.94      ],
        [0.99      ],
        [0.92      ],
        [0.99      ],
        [0.98      ],
        [1.        ],
        [0.86868687],
        [0.95      ],
        [0.62921348],
        [0.95      ],
        [0.97      ]]),
 0.8626328699918234)