In [1]:
import torch
import torchvision
from torchvision import models
import torchvision.transforms as transforms
from torch import nn
from torch.nn import functional as F

import numpy as np
import matplotlib.pyplot as plt

import utils
import utils.data, utils.ML, utils.models
from utils.models import number_of_parameters

import copy

In [2]:
device = "mps"

In [8]:
batch_size = 32
num_workers = 4

loss_fn = nn.CrossEntropyLoss()

In [6]:
transform_CIFAR = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataloader, val_dataloader, test_dataloader = utils.data.get_CIFAR_data_loaders(batch_size, transform_CIFAR, num_workers=0)

alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)

In [4]:
class upsampled_classifier(nn.Module):
    
    def __init__(self):
        super(upsampled_classifier, self).__init__()
        
        self.upsampling = nn.Sequential(nn.ConvTranspose2d(3, 3, (2, 2), stride=2, bias=False),
                                       nn.ConvTranspose2d(3, 3, (2, 2), stride=2, bias=False),
                                       nn.ConvTranspose2d(3, 3, (2, 2), stride=2, bias=False)) # out: 3x256x256
        
        self.upsampling[0].weight.data.fill_(1.0)
        self.upsampling[1].weight.data.fill_(1.0)
        self.upsampling[2].weight.data.fill_(1.0)
        
        self.features = alexnet.features  # output: 256x6x6
        
        self.avg_pooling = nn.AdaptiveAvgPool2d(output_size=(6,6))
        self.flatten = nn.Flatten()
        
        self.classifier = nn.Sequential(nn.Linear(9216, 1024),
                                        nn.ReLU(),
                                        nn.Linear(1024, 10))
        
    def forward(self, x):

        x = self.upsampling(x)
        x = self.features(x)
        x = self.avg_pooling(x)
        x = self.flatten(x)
        x = self.classifier(x)

        return x

In [9]:
model = upsampled_classifier().to(device)

for param in model.features.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 5

train_loss_log, val_loss_log = utils.ML.train_model(model, train_dataloader, val_dataloader, optimizer, loss_fn, num_epochs, device, verbose=True)


 epoch:  1, training loss: inf, validation loss 1.081, validation accuracy 0.619


KeyboardInterrupt: 

In [None]:
X, y = next(iter(test_dataloader))

X_upsample = model.upsampling(X.to("mps"))

print(X_upsample.shape)

In [None]:
i = 1

im = X_upsample[i].detach().to('cpu')

print(im.min(), im.max())

im = (im - im.min())/(im.max() - im.min())

print(im.min(), im.max())

plt.figure()
plt.imshow(np.transpose(im, (1, 2, 0)))