In [None]:
import argparse
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from tqdm import tqdm

In [None]:
### Hyperparameters
# -- Data partitioning --
val_split = 0.1        # fraction of the MNIST training images held out for validation
unused_size = 0.99     # fraction of the remaining training images placed in the unused pool
data_iterations = 50   # number of "train, then grow the training set" rounds

# -- Optimisation --
num_epochs = 100       # epochs of training per round
batch_size = 64
lr = 0.0005            # learning rate handed to Adam

# Seed torch's global RNG so the random permutations and shuffling are repeatable.
torch.manual_seed(42)

In [None]:
### Setup MNIST dataset
# Preprocessing applied to every image: convert to a tensor, then normalise
# the single channel with mean 0.5 and standard deviation 0.5.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Official MNIST training split, downloaded into ./data when not already present.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

In [None]:
# Independent copy of the full training set. A later cell trims its `data` and
# `targets` down to the validation subset, so the copy must not share state
# with `train_dataset`.
val_dataset = deepcopy(train_dataset)

In [None]:
# One random permutation of all sample positions is drawn here. The validation
# set takes its tail (below); the head is taken by the training set later on,
# so the two subsets cannot overlap.
train_size = int((1 - val_split) * len(train_dataset))
val_size = len(train_dataset) - train_size
indexes = torch.randperm(len(train_dataset)).tolist()

# Define validation set
indexes_val = indexes[train_size:]
val_dataset.targets, val_dataset.data = (
    val_dataset.targets[indexes_val],
    val_dataset.data[indexes_val],
)

# Evaluation only: large batches, fixed order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1024,
    shuffle=False,
)

In [None]:
# Define training set: the first `train_size` entries of the shared permutation.
indexes_train = indexes[:train_size]
train_dataset.targets, train_dataset.data = (
    train_dataset.targets[indexes_train],
    train_dataset.data[indexes_train],
)

In [None]:
# Split training data into labelled and unlabelled
# NOTE: `unused_size` is rebound from a fraction to an absolute sample count and
# `train_dataset` is shrunk in place, so this cell must not be re-run on its own.
unused_size = int(unused_size * len(train_dataset))
indexes_train = torch.randperm(len(train_dataset)).tolist()  # Redefine indexes_train
unused_positions = indexes_train[:unused_size]
labelled_positions = indexes_train[unused_size:]

# Copy first: the unused pool is carved out of the training set as it is now,
# before the training set itself is reduced to the labelled remainder.
unused_dataset = deepcopy(train_dataset)
unused_dataset.targets = unused_dataset.targets[unused_positions]
unused_dataset.data = unused_dataset.data[unused_positions]

train_dataset.targets = train_dataset.targets[labelled_positions]
train_dataset.data = train_dataset.data[labelled_positions]

# Snapshots of the starting state of both datasets.
start_train_dataset = deepcopy(train_dataset)  # Save for baseline
start_unlabbelled_dataset = deepcopy(unused_dataset)  # Save for baseline

In [None]:
def transfer_unused_to_labeled(unused_dataset, train_dataset, indexes):
    """Move the selected samples from the unused pool into the training set.

    Both dataset objects are modified in place (their ``data`` and ``targets``
    tensors are replaced) and are also returned for convenience. Moved samples
    are appended to the training set in pool order, regardless of the order in
    which their positions appear in ``indexes``.

    Args:
        unused_dataset: Object exposing ``data`` and ``targets`` tensors that
            are aligned along their first dimension; samples are removed from it.
        train_dataset: Object exposing ``data`` and ``targets`` tensors; the
            moved samples are concatenated onto it.
        indexes: Integer positions into ``unused_dataset`` to move, given as
            any iterable (``range``, list, ...) or as a tensor. Positions
            outside ``[0, len(unused_dataset.targets))`` are ignored and
            duplicates are moved once.

    Returns:
        The tuple ``(train_dataset, unused_dataset)``.
    """
    pool_size = len(unused_dataset.targets)

    # Normalise the selection to a flat tensor of integer positions.
    if isinstance(indexes, torch.Tensor):
        positions = indexes.reshape(-1).to(dtype=torch.long, device='cpu')
    else:
        positions = torch.as_tensor(list(indexes), dtype=torch.long)
    # Silently drop positions that do not exist in the pool.
    positions = positions[(positions >= 0) & (positions < pool_size)]

    # Build the boolean mask directly instead of testing `i in indexes` for
    # every pool position. The explicit bool dtype also keeps an empty pool
    # working (an empty Python list would otherwise become a float tensor,
    # which cannot be used as an index).
    selected = torch.zeros(pool_size, dtype=torch.bool)
    selected[positions] = True

    train_dataset.targets = torch.cat([train_dataset.targets, unused_dataset.targets[selected]])
    train_dataset.data = torch.cat([train_dataset.data, unused_dataset.data[selected]])
    unused_dataset.targets = unused_dataset.targets[~selected]
    unused_dataset.data = unused_dataset.data[~selected]

    return train_dataset, unused_dataset

In [None]:
def validate_model(model, val_loader, device):
    """Return the classification accuracy of ``model`` on ``val_loader`` in percent.

    The model is switched to evaluation mode and left in that mode; gradients
    are not tracked during the pass.

    Args:
        model: Callable module mapping a batch of images to per-class scores
            along dimension 1.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` over every sample yielded by ``val_loader``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Predicted class = index of the largest score in each row.
            predicted = model(batch_images).max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return 100 * n_correct / n_seen

In [None]:
# Setup model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Randomly initialised ResNet-18. The `pretrained=False` flag is deprecated in
# newer torchvision releases; "no pretrained weights" is the default on both
# old and new versions, so the argument is simply omitted.
model = torchvision.models.resnet18()
# Replace the classification head with a 10-way output.
model.fc = torch.nn.Linear(model.fc.in_features, 10)
# Modify input layer to accept 1 channel
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Snapshot of the initial weights, taken before the model is moved to `device`.
# It is loaded back into the model at the start of every training round.
model_parameters = deepcopy(model.state_dict())
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10, val_interval=1):
    """Train ``model`` for ``num_epochs`` epochs, validating periodically.

    Args:
        model: Module to optimise; it is put in training mode at the start of
            every epoch.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Loader passed to ``validate_model`` for evaluation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        val_interval: Validation runs after every ``val_interval``-th epoch.

    Returns:
        List of validation accuracies, one entry per validation run, in order.
    """
    history = []
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_images), batch_labels)
            batch_loss.backward()
            optimizer.step()

        # Skip validation unless this epoch closes a validation interval.
        if (epoch + 1) % val_interval != 0:
            continue

        val_accuracy = validate_model(model, val_loader, device)
        history.append(val_accuracy)
        print(f'Epoch {epoch + 1}, Accuracy: {val_accuracy:.2f}%')
    return history

In [None]:
def add_data_iteration(model, train_dataset, unused_dataset, device, top_frac=0.01):
    """Move the leading ``top_frac`` share of the unused pool into the training set.

    The samples taken are the first ``int(top_frac * len(unused_dataset))``
    entries of the pool in their stored order; model predictions play no part
    in the choice. Earlier code ran a full softmax inference pass over the pool
    here, but only the number of predictions (equal to the pool size) was ever
    used, so that pass has been removed. The datasets returned are identical;
    the function simply no longer switches ``model`` to eval mode or shows a
    progress bar.

    Args:
        model: Not used; kept so existing call sites keep working.
        train_dataset: Dataset that receives the moved samples (modified in place).
        unused_dataset: Pool the samples are taken from (modified in place).
        device: Not used; kept so existing call sites keep working.
        top_frac: Fraction of the current pool to move.

    Returns:
        The tuple ``(train_dataset, unused_dataset)`` after the transfer.
    """
    number_of_new_data_points = int(top_frac * len(unused_dataset))
    indices = range(number_of_new_data_points)
    print(f"Adding {len(indices)} images to training set")
    train_dataset, unlabelled_dataset = transfer_unused_to_labeled(unused_dataset, train_dataset, indices)

    return train_dataset, unlabelled_dataset

In [None]:
# Per-round records: training-set size and the validation accuracies logged
# while training on it.
datapoint_list = []
accuracy_list = []

for i in range(data_iterations):
    print(i)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    model.load_state_dict(model_parameters)  # Important to reset the model each time
    # Reset the optimizer together with the model: reusing the old Adam
    # instance would carry its moment estimates and step counts from the
    # previous round over to the freshly re-initialised weights.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    accuracies = train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=num_epochs, val_interval=10)
    datapoint_list.append(len(train_dataset))
    accuracy_list.append(accuracies)
    # Grow the training set for the next round; nothing to add after the last one.
    if i < data_iterations - 1:
        train_dataset, unused_dataset = add_data_iteration(model, train_dataset, unused_dataset, device, top_frac=0.001)

In [None]:
# Plot the accuracy
datapoints = np.array(datapoint_list)
# One value per round: the highest validation accuracy logged during that round.
accuracies = np.array(accuracy_list).max(-1)
plt.figure(figsize=(10, 5))
# The line needs a label, otherwise plt.legend() has nothing to show and warns.
plt.plot(datapoints, accuracies, label='Best validation accuracy')

plt.xlabel('Datapoints')
plt.ylabel('Accuracy')
plt.title('Accuracy of resnet18 for increasing number of traningset size')
plt.legend()

plt.tight_layout()
# Make sure the output directory exists so the save does not fail after a long run.
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/1_Increasing_traningset_size.png')
plt.show()