In [None]:
import datetime
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.transforms import ToTensor

import utils

In [None]:
# Both datasets are resized to 32x32, 3-channel images and scaled from [0, 1]
# to [-1, 1] (Normalize with mean=std=0.5), so one model input pipeline fits both.
#
# Measured per-channel statistics, kept for reference only (NOT used below):
#   MNIST: mean tensor([0.1309, 0.1309, 0.1309]), std tensor([0.2893, 0.2893, 0.2893])
#   SVHN:  mean tensor([0.4377, 0.4438, 0.4728]), std tensor([0.1980, 0.2010, 0.1970])
SEED = 42
BATCH_SIZE = 64

# Seeds the global torch RNG, which makes the shuffled training order (and any
# model initialisation done afterwards) reproducible across kernel restarts.
torch.manual_seed(SEED)

normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# MNIST is single-channel; Grayscale(3) replicates it to 3 channels to match SVHN.
transform_mnist = transforms.Compose([
    transforms.Grayscale(3),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    normalize,
])
transform_svhn = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    normalize,
])

train_dataset_mnist = datasets.MNIST(root='data', train=True, transform=transform_mnist, download=True)
test_dataset_mnist = datasets.MNIST(root='data', train=False, transform=transform_mnist, download=True)

train_dataset_svhn = datasets.SVHN(root='data/SVHN', split='train', transform=transform_svhn, download=True)
test_dataset_svhn = datasets.SVHN(root='data/SVHN', split='test', transform=transform_svhn, download=True)

## Creating a concatenated dataset (all MNIST samples first, then all SVHN samples)
train_dataset_combined = ConcatDataset([train_dataset_mnist, train_dataset_svhn])
test_dataset_combined = ConcatDataset([test_dataset_mnist, test_dataset_svhn])

# Dataloaders. Training loaders shuffle: without it the combined loader would
# feed every MNIST batch before any SVHN batch, since ConcatDataset keeps the
# datasets in order. Test loaders keep a fixed order for repeatable evaluation.
train_dataloader_mnist = DataLoader(train_dataset_mnist, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader_mnist = DataLoader(test_dataset_mnist, batch_size=BATCH_SIZE, shuffle=False)

train_dataloader_svhn = DataLoader(train_dataset_svhn, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader_svhn = DataLoader(test_dataset_svhn, batch_size=BATCH_SIZE, shuffle=False)

train_dataloader_combined = DataLoader(train_dataset_combined, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader_combined = DataLoader(test_dataset_combined, batch_size=BATCH_SIZE, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# One model per training set: MNIST, SVHN and the MNIST+SVHN combination.
# Built without a checkpoint path (presumably freshly initialised — see
# utils.create_resnet_model). These are the instances the training cell uses;
# the checkpoint-loading cell later rebinds the same three names.
model_mnist, model_svhn, model_combined = (
    utils.create_resnet_model() for _ in range(3)
)

In [None]:
# Training is switched off by default; the next cell loads previously saved
# checkpoints from ./pths. Set TRAIN = True to retrain all three models
# (10 epochs, lr=0.001 each) and overwrite those checkpoint files.
TRAIN = False

if TRAIN:
    # torch.save does not create missing parent directories.
    os.makedirs('./pths', exist_ok=True)
    for model, loader, checkpoint_path in (
        (model_mnist, train_dataloader_mnist, './pths/mnist32.pth'),
        (model_svhn, train_dataloader_svhn, './pths/svhn32.pth'),
        (model_combined, train_dataloader_combined, './pths/combined32.pth'),
    ):
        utils.train(model, loader, epochs=10, lr=0.001)
        torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Rebuild the three models, passing each one's saved checkpoint path
# (presumably loaded by utils.create_resnet_model). This rebinds the
# model_mnist / model_svhn / model_combined names.
model_mnist, model_svhn, model_combined = (
    utils.create_resnet_model(checkpoint)
    for checkpoint in ('./pths/mnist32.pth', './pths/svhn32.pth', './pths/combined32.pth')
)

In [None]:
# Evaluate each model on the test split of its own dataset
# (MNIST model on MNIST, SVHN model on SVHN, combined model on MNIST+SVHN).
for name, model, loader in (
    ('model_mnist', model_mnist, test_dataloader_mnist),
    ('model_svhn', model_svhn, test_dataloader_svhn),
    ('model_combined', model_combined, test_dataloader_combined),
):
    test_loss, accuracy = utils.evaluate(model, loader)
    print(f"{name} - Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}")

In [None]:
# Evaluate every model on the combined MNIST + SVHN test set, so the
# single-dataset models can be compared with the model trained on both.
for name, model in (
    ('model_mnist', model_mnist),
    ('model_svhn', model_svhn),
    ('model_combined', model_combined),
):
    test_loss, accuracy = utils.evaluate(model, test_dataloader_combined)
    print(f"{name} - Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}")