In [10]:
from temperature_scaling import ModelWithTemperature
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import timm

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import os
from tqdm import tqdm

%matplotlib inline

In [22]:
# Load the data, and split train into train and validation

batch_size = 512

# Seed torch's global RNG (weight init, loader shuffling) and use a dedicated
# generator for the split so the experiment is reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# One preprocessing pipeline shared by the train and test sets.
# (0.1307, 0.3081) are presumably the MNIST pixel mean/std -- TODO confirm.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Keep the full dataset under its own name so re-running this cell never
# splits an already-split subset.
mnist_full = datasets.MNIST('data', train=True, download=True, transform=mnist_transform)

mnist_train, mnist_val = torch.utils.data.random_split(
    mnist_full, [50000, 10000], generator=torch.Generator().manual_seed(SEED)
)

# Split train dataset into labeled and unlabeled
# `mnist_train.indices[i]` is the index in the full dataset of the i-th training
# example. Examples whose original index is below 10000 form the labeled set, so
# its size depends on the random split (it is not exactly 10000 examples).
original_indices = np.asarray(mnist_train.indices)
labeled_positions = np.where(original_indices < 10000)[0].tolist()
unlabeled_positions = np.where(original_indices >= 10000)[0].tolist()

mnist_train_labeled = torch.utils.data.Subset(mnist_train, labeled_positions)
mnist_train_unlabeled = torch.utils.data.Subset(mnist_train, unlabeled_positions)

# The two subsets must partition the training split.
assert len(mnist_train_labeled) + len(mnist_train_unlabeled) == len(mnist_train)

train_labeled_loader = torch.utils.data.DataLoader(mnist_train_labeled, batch_size=batch_size, shuffle=True)
# The unlabeled loader is only used for inference; it must not shuffle, otherwise
# the predictions it yields are not in the same order as `mnist_train_unlabeled`.
train_unlabeled_loader = torch.utils.data.DataLoader(mnist_train_unlabeled, batch_size=batch_size, shuffle=False)

val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=batch_size, shuffle=False)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

In [23]:
# Sanity check: pull only the first batch from the labeled loader and show its shapes

first_batch = next(iter(train_labeled_loader), None)
if first_batch is not None:
    data, target = first_batch
    print("Data: ", data.size())
    print("Target: ", target.size())

Data:  torch.Size([512, 1, 28, 28])
Target:  torch.Size([512])


In [24]:
# Teacher network: flatten the 28x28 image, one hidden layer of 100 ReLU units,
# then 10 output logits. Requires a CUDA device.
mlp_layers = [
    nn.Flatten(),
    nn.Linear(28 * 28, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
]

model = nn.Sequential(*mlp_layers).cuda()

In [28]:
# Train the model
def train(model, optimizer, train_loader, val_loader, epochs=10):
    """Train `model` with cross-entropy and report accuracy after every epoch.

    For each epoch the model is updated on every batch of `train_loader`, then
    evaluated (without gradients) on `val_loader`. One "Train accuracy" line and
    one "Validation accuracy" line are printed per epoch. Training accuracy is
    accumulated while the weights are still being updated within the epoch.

    Args:
        model: classifier returning one row of class logits per input example.
        optimizer: optimizer already bound to `model`'s parameters.
        train_loader: iterable of `(data, target)` batches used for the updates.
        val_loader: iterable of `(data, target)` batches used for evaluation.
        epochs: number of passes over `train_loader`.

    Returns:
        None. The model is left in training mode.
    """
    # Run on whatever device the model already lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(epochs):
        correct = 0
        total = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += data.size(0)
        # An empty loader reports nan rather than raising ZeroDivisionError.
        train_accuracy = correct / total if total else float("nan")
        print(f"Train accuracy: {train_accuracy:.4f}")

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                correct += (output.argmax(dim=1) == target).sum().item()
                total += data.size(0)
        val_accuracy = correct / total if total else float("nan")
        print(f"Validation accuracy: {val_accuracy:.4f}")
        # Switch back so the next epoch trains in training mode.
        model.train()

In [30]:
# Fit the MLP on the labeled subset, monitoring accuracy on the validation split.
# NOTE(review): the recorded output already shows ~99% train accuracy in the first
# epoch, which suggests this cell was run on a previously trained `model` -- TODO
# confirm by re-running from a fresh kernel.
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.01)
train(
    model=model,
    optimizer=optimizer,
    train_loader=train_labeled_loader,
    val_loader=val_loader,
    epochs=10,
)

Train accuracy: 0.9931
Validation accuracy: 0.9363
Train accuracy: 0.9932
Validation accuracy: 0.9381
Train accuracy: 0.9937
Validation accuracy: 0.9363
Train accuracy: 0.9940
Validation accuracy: 0.9376
Train accuracy: 0.9937
Validation accuracy: 0.9368
Train accuracy: 0.9946
Validation accuracy: 0.9371
Train accuracy: 0.9944
Validation accuracy: 0.9364
Train accuracy: 0.9948
Validation accuracy: 0.9371
Train accuracy: 0.9945
Validation accuracy: 0.9376
Train accuracy: 0.9952
Validation accuracy: 0.9372


In [31]:
# Use model to generate pseudo-labels for unlabeled data

def generate_pseudo_labels(model, unlabeled_loader):
    """Predict a hard class label for every example yielded by `unlabeled_loader`.

    Labels are returned in the order the loader yields its examples. To keep
    label `i` aligned with dataset item `i`, pass a loader that does not shuffle.

    Args:
        model: classifier returning one row of class logits per input example.
        unlabeled_loader: iterable of batches. A batch may be the input tensor
            itself or a list/tuple whose first element is the input tensor; any
            further elements (e.g. targets) are ignored.

    Returns:
        1-D integer tensor of predicted class indices on the model's device.
        Empty when the loader yields no batches. The model is left in eval mode.
    """
    # Run on whatever device the model already lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    pseudo_labels = []
    with torch.no_grad():
        for batch in unlabeled_loader:
            data = batch[0] if isinstance(batch, (list, tuple)) else batch
            output = model(data.to(device))
            pseudo_labels.append(output.argmax(dim=1))

    # torch.cat raises on an empty list; return an empty label tensor instead.
    if not pseudo_labels:
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(pseudo_labels)

In [32]:
labels = generate_pseudo_labels(model, train_unlabeled_loader)

In [35]:
# Combine labeled and pseudo-labeled data, note that train_unlabeled_loader is Subset of mnist_train_unlabeled, so we need to access the dataset attribute of the loader to get the original dataset

mnist_train_unlabeled = train_unlabeled_loader.dataset
mnist_train_unlabeled.targets = labels

mnist_train_combined = torch.utils.data.ConcatDataset([mnist_train_labeled.dataset, mnist_train_unlabeled])

In [46]:
# Train the logistic regression model and report its accuracy on the test set

def train_logistic_regression(model, optimizer, train_loader, test_loader, epochs=10):
    """Train `model` with cross-entropy, then evaluate it once on `test_loader`.

    One "Train accuracy" line is printed per epoch and a single "Test accuracy"
    line after the final epoch. Nothing is returned; the fitted coefficients
    stay in `model`'s parameters.

    Args:
        model: classifier returning one row of class logits per input example.
        optimizer: optimizer already bound to `model`'s parameters.
        train_loader: iterable of `(data, target)` batches used for the updates.
        test_loader: iterable of `(data, target)` batches used for the final
            evaluation.
        epochs: number of passes over `train_loader`.

    Returns:
        None. The model is left in eval mode.
    """
    # Run on whatever device the model already lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(epochs):
        correct = 0
        total = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += data.size(0)
        # An empty loader reports nan rather than raising ZeroDivisionError.
        train_accuracy = correct / total if total else float("nan")
        print(f"Train accuracy: {train_accuracy:.4f}")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            correct += (output.argmax(dim=1) == target).sum().item()
            total += data.size(0)

    test_accuracy = correct / total if total else float("nan")
    print(f"Test accuracy: {test_accuracy:.4f}")

In [48]:
# Baseline: logistic regression (a single linear layer on flattened pixels)
# fitted on the labeled subset only. Requires a CUDA device.

logistic_model_labeled = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).cuda()

optimizer = optim.SGD(logistic_model_labeled.parameters(), momentum=0.9, lr=0.01)
train_logistic_regression(
    model=logistic_model_labeled,
    optimizer=optimizer,
    train_loader=train_labeled_loader,
    test_loader=test_loader,
    epochs=10,
)

Train accuracy: 0.6436
Train accuracy: 0.8721
Train accuracy: 0.8927
Train accuracy: 0.8997
Train accuracy: 0.9090
Train accuracy: 0.9125
Train accuracy: 0.9165
Train accuracy: 0.9208
Train accuracy: 0.9215
Train accuracy: 0.9235
Test accuracy: 0.9067


In [49]:
# Logistic regression fitted on the combined dataset

train_combined_loader = torch.utils.data.DataLoader(
    mnist_train_combined,
    batch_size=batch_size,
    shuffle=True,
)

# Same architecture as the labeled-only baseline: one linear layer on flattened pixels.
logistic_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).cuda()

optimizer = optim.SGD(logistic_model.parameters(), momentum=0.9, lr=0.01)
train_logistic_regression(
    model=logistic_model,
    optimizer=optimizer,
    train_loader=train_combined_loader,
    test_loader=test_loader,
    epochs=10,
)

Train accuracy: 0.8704
Train accuracy: 0.9135
Train accuracy: 0.9184
Train accuracy: 0.9217
Train accuracy: 0.9234
Train accuracy: 0.9251
Train accuracy: 0.9262
Train accuracy: 0.9270
Train accuracy: 0.9278
Train accuracy: 0.9283
Test accuracy: 0.9225


In [53]:
# A second, freshly initialised logistic regression trained on the combined
# dataset; this copy is the one later wrapped for temperature scaling.

logistic_model_ts = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).cuda()

optimizer = optim.SGD(logistic_model_ts.parameters(), momentum=0.9, lr=0.01)
train_logistic_regression(
    model=logistic_model_ts,
    optimizer=optimizer,
    train_loader=train_combined_loader,
    test_loader=test_loader,
    epochs=10,
)

Train accuracy: 0.8707
Train accuracy: 0.9140
Train accuracy: 0.9181
Train accuracy: 0.9213
Train accuracy: 0.9232
Train accuracy: 0.9252
Train accuracy: 0.9263
Train accuracy: 0.9268
Train accuracy: 0.9279
Train accuracy: 0.9284
Test accuracy: 0.9211


In [54]:
# Temperature scaling

# Wrap the trained classifier and fit its temperature on the validation loader.
# NOTE(review): the recorded output reports NLL 0.306 -> 0.316 and ECE
# 0.010 -> 0.048, i.e. calibration got worse after scaling. That suggests the
# temperature fit inside `set_temperature` did not reach a good optimum --
# TODO confirm against the `temperature_scaling` implementation.
logistic_model_ts_temperature = ModelWithTemperature(logistic_model_ts)
logistic_model_ts_temperature.set_temperature(val_loader)

# Last expression of the cell: display the fitted temperature parameter.
logistic_model_ts_temperature.temperature

Before temperature - NLL: 0.306, ECE: 0.010
Optimal temperature: 1.269
After temperature - NLL: 0.316, ECE: 0.048


Parameter containing:
tensor([1.2695], device='cuda:0', requires_grad=True)

In [55]:
# Measure test-set accuracy of the temperature-scaled model.
# Presumably scaling only rescales the logits, in which case the argmax (and so
# the accuracy) matches the unscaled model -- the recorded outputs agree.

correct = 0
total = 0

logistic_model_ts_temperature.eval()
with torch.no_grad():
    for data, target in test_loader:
        data = data.cuda()
        target = target.cuda()
        pred = logistic_model_ts_temperature(data).argmax(dim=1)
        correct += (pred == target).sum().item()
        total += len(data)

print(f"Test accuracy: {correct/total:.4f}")

Test accuracy: 0.9211
