In [2]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.data import Dataset, Subset, DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm # Displays a progress bar

In [3]:
# Load the dataset and train, val, test splits
print("Loading datasets...")
dataset_path = "C:/Users/Admin/Desktop/cse803_hw5"

MNIST_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.1307], [0.3081])
])
MNIST_train = datasets.MNIST(
    dataset_path,
    download=True,
    train=True,
    transform=MNIST_transform
)
MNIST_test = datasets.MNIST(
    dataset_path,
    download=True,
    train = False,
    transform=MNIST_transform
)

FASHION_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.2859], [0.3530])
])
FASHION_train = datasets.FashionMNIST(
    dataset_path,
    download=True,
    train=True,
    transform=MNIST_transform
)
FASHION_test = datasets.FashionMNIST(
    dataset_path,
    download=True,
    train=False,
    transform=FASHION_transform
)


Loading datasets...
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\train-images-idx3-ubyte.gz


100%|████████████████████████████████████████| 9912422/9912422 [00:02<00:00, 4386004.32it/s]


Extracting C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\train-images-idx3-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\train-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████| 28881/28881 [00:00<00:00, 925045.96it/s]

Extracting C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\train-labels-idx1-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 6664769.54it/s]


Extracting C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\t10k-images-idx3-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|███████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]

Extracting C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw\t10k-labels-idx1-ubyte.gz to C:/Users/Admin/Desktop/cse803_hw5\MNIST\raw






In [6]:
"""
Data Loaders.
"""
class GridDataset(Dataset):
    def __init__(self, MNIST_dataset, FASHION_dataset): # pass in dataset
        assert len(MNIST_dataset) == len(FASHION_dataset)
        self.MNIST_dataset, self.FASHION_dataset = MNIST_dataset, FASHION_dataset
        self.targets = FASHION_dataset.targets
        torch.manual_seed(442) # Fix random seed for reproducibility
        N = len(MNIST_dataset)
        self.randpos = torch.randint(low=0,high=4,size=(N,)) # position of the FASHION-MNIST image
        self.randidx = torch.randint(low=0,high=N,size=(N,3)) # indices of MNIST images
    
    def __len__(self):
        return len(self.MNIST_dataset)
    
    def __getitem__(self,idx): # Get one Fashion-MNIST image and three MNIST images to make a new image
        idx1, idx2, idx3 = self.randidx[idx]
        x = self.randpos[idx]%2
        y = self.randpos[idx]//2
        p1 = self.FASHION_dataset.__getitem__(idx)[0]
        p2 = self.MNIST_dataset.__getitem__(idx1)[0]
        p3 = self.MNIST_dataset.__getitem__(idx2)[0]
        p4 = self.MNIST_dataset.__getitem__(idx3)[0]
        combo = torch.cat((torch.cat((p1,p2),2),torch.cat((p3,p4),2)),1)
        combo = torch.roll(combo, (x*28,y*28), dims=(0,1))
        return (combo,self.targets[idx])

trainset = GridDataset(MNIST_train, FASHION_train)
testset = GridDataset(MNIST_test, FASHION_test)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=True)

In [5]:
"""
Network class.
"""
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: Design your own base module, define layers here
        self.base = nn.Sequential(
            nn.Conv2d(1,32,5,1,2),
            nn.ReLU(),
        )
        out_channel = 32 # TODO: Put the output channel number of your base module here
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(out_channel,10)
        self.conv = nn.Conv2d(out_channel,10,1) # 1x1 conv layer (substitutes fc)

    def transfer(self): # Copy weights of fc layer into 1x1 conv layer
        self.conv.weight = nn.Parameter(self.fc.weight.unsqueeze(2).unsqueeze(3))
        self.conv.bias = nn.Parameter(self.fc.bias)

    def visualize(self,x):
        x = self.base(x)
        x = self.conv(x)
        return x
        
    def forward(self,x):
        x = self.base(x)
        x = self.avgpool(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        return x

In [7]:
"""
Hyperparameters.
"""
# configure device
device = "cuda" if torch.cuda.is_available() else "cpu"

# init model
model = Network().to(device)

# specify the loss layer
criterion = nn.CrossEntropyLoss()

# TODO: Modify the line below, experiment with different optimizers and parameters (such as learning rate)
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4
)

# TODO: choose an appropriate number of training epochs
num_epoch = 10

In [30]:
"""
Train & evaluation functions.
"""
def train(model, train_loader, val_loader, num_epoch = 10): # Train the model
    print("Start training...")
    train_losses = []
    val_losses = []
    
    for i in range(num_epoch):
        # Set the model to training mode
        model.train()
        running_loss = []
        for batch, label in tqdm(train_loader):
            # format data
            batch = batch.to(device)
            label = label.to(device)
            
            # Clear gradients from the previous iteration
            optimizer.zero_grad()
            
            # This will call Network.forward() that you implement
            pred = model(batch)
            
            # Calculate the training loss
            loss = criterion(pred, label)
            running_loss.append(loss.item())
            
            # Backprop gradients to all tensors in the network
            loss.backward()
            
            # Update trainable weights
            optimizer.step()
        
        # training loss
        train_loss = np.mean(running_loss)
        train_losses.append(train_loss)
        
        # validation loss
        _, val_loss = evaluate(model, val_loader)
        val_losses.append(val_loss)
        
        # report epoch results
        print(f"Epoch {i+1}: train_loss={train_loss}, val_loss={val_loss}") # Print the average losses for this epoch
    
    # finished
    print("Done!")
    return train_losses, val_losses

def evaluate(model, val_loader): # Evaluate accuracy on validation / test set
    model.eval() # Set the model to evaluation mode
    running_loss = []
    correct = 0
    with torch.no_grad(): # Do not calculate grident to speed up computation
        for batch, label in tqdm(val_loader):
            # format data
            batch = batch.to(device)
            label = label.to(device)
            
            # make predictions
            pred = model(batch)
            
            # Calculate the validation loss
            loss = criterion(pred, label)
            running_loss.append(loss.item())
            
            # calculate batch accuracy
            correct += (torch.argmax(pred,dim=1)==label).sum().item()
    
    # averaged accuracy
    acc = correct / len(val_loader.dataset)
    
    # validation loss
    val_loss = np.mean(running_loss)
    
    # finished
    print("Evaluation accuracy: {}".format(acc))
    return acc, val_loss

In [None]:
"""
Train and evaluate model.
"""
# train
train_losses, val_losses = train(
    model,
    trainloader,
    valloader,
    num_epoch
)

print("Evaluate on test set")
test_acc, test_loss = evaluate(
    model,
    testloader
)

model.transfer() # Copy the weights from fc layer to 1x1 conv layer

# TODO: Choose a correctly classified image and visualize it

In [None]:
"""
Analyze training & evaluation results.
"""
results = []
for i, (t_loss, v_loss) in enumerate(zip(train_losses, val_losses)):
    results.append({
        'epoch': i,
        'training_loss': t_loss,
        'validation_loss': v_loss,
    })

results_df = pd.DataFrame.from_records(results).set_index('epoch')
print(results_df)

# plot figure
results_df.plot(
    xlabel="Epoch",
    ylabel="Loss",
    grid=True,
)
plt.title("Q2: Training Loss vs Epoch", fontsize=10)
plt.savefig(f"./figures/{"q2_losses"}.png")