In [1]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import scipy.io as sio
import kornia
import random
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from matplotlib import pyplot as plt
%matplotlib inline
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"

In [3]:
# Dataset selector for this notebook. The names recognised by TensorDataset
# and getModel below are (case-insensitively): 'mnist', 'fmnist', 'stl10',
# 'cifar10', 'cifar100' and 'imagenet'.
dataset_name = 'imagenet'

In [4]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which would select 'cuda' even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [5]:
def plot(img):
    """Display `img` on a fresh 5x5-inch figure using the 'gray' colormap."""
    _, axes = plt.subplots(figsize=(5, 5))
    axes.imshow(img, cmap='gray')
    plt.show()

In [6]:
# Reusable converters between PIL images and tensors (torchvision transforms).
PIL2Ten = torchvision.transforms.ToTensor()    # PIL image -> tensor
Ten2PIL = torchvision.transforms.ToPILImage()  # tensor -> PIL image

In [7]:
class TensorDataset(torch.utils.data.Dataset):
    """Thin wrapper around torchvision datasets that yields (image, label) pairs.

    The wrapped dataset is created without a transform; `transform` is applied
    here, in `__getitem__`, and the label is returned as a one-element tensor.

    Args:
        dataset_name: Case-insensitive dataset identifier. One of 'mnist',
            'fmnist', 'stl10', 'cifar10', 'cifar100' or 'imagenet'.
        root: Directory handed to the torchvision dataset constructor.
            NOTE: not used for 'imagenet', which always reads from
            ``data/ILSVRC2012/<train|val>``.
        train: Truthy selects the training split, falsy the test split
            ('val' directory for ImageNet).
        download: Forwarded to the torchvision dataset constructor
            (not used for 'imagenet').
        transform: Callable applied to every PIL image. For 'imagenet' it is
            run after a fixed ``Resize((224, 224))``.
        target_transform: Optional callable applied to the raw label before
            it is wrapped into a tensor. ``None`` (default) leaves the label
            unchanged.

    Raises:
        NotImplementedError: If `dataset_name` is not one of the names above.
    """

    def __init__(self, dataset_name, root, train, download, transform, target_transform=None):
        super().__init__()
        self.transform = transform
        self.target_transform = target_transform

        name = dataset_name.lower()
        if name == 'mnist':
            self.dataset = torchvision.datasets.MNIST(root=root, train=train, download=download)
        elif name == 'fmnist':
            self.dataset = torchvision.datasets.FashionMNIST(root=root, train=train, download=download)
        elif name == 'stl10':
            # STL10 selects its split by name rather than by a boolean flag.
            split_name = 'train' if train else 'test'
            self.dataset = torchvision.datasets.STL10(root=root, split=split_name, download=download)
        elif name == 'cifar10':
            self.dataset = torchvision.datasets.CIFAR10(root=root, train=train, download=download)
        elif name == 'cifar100':
            self.dataset = torchvision.datasets.CIFAR100(root=root, train=train, download=download)
        elif name == 'imagenet':
            # ImageNet is read from a fixed on-disk folder layout; `root` and
            # `download` are intentionally left untouched to keep existing
            # callers working.
            split_dir = 'train' if train else 'val'
            self.dataset = torchvision.datasets.ImageFolder(os.path.join('data', 'ILSVRC2012', split_dir))
            # Bring every image to 224x224 before the user transform runs.
            self.transform = torchvision.transforms.Compose(
                [torchvision.transforms.Resize((224, 224)), transform]
            )
        else:
            raise NotImplementedError(
                "Only MNIST, FashionMNIST (fmnist), STL10, CIFAR10, CIFAR100 and ImageNet supported for now!"
            )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transform(image), label_tensor)`` for sample `index`.

        The label is wrapped with ``torch.Tensor([label])``, i.e. a
        one-element tensor of the default floating-point dtype.
        """
        pil_img, label = self.dataset[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return (self.transform(pil_img), torch.Tensor([label]))

In [8]:
def getModel(dataset_name):
    """Build the classifier network associated with a dataset.

    Args:
        dataset_name: Case-insensitive dataset identifier.
            * 'mnist' / 'fmnist': small 3-stage CNN with 1 input channel and
              10 outputs.
            * 'stl10' / 'cifar10' / 'cifar100': randomly initialised
              torchvision VGG11-BN whose output width matches the dataset's
              class count (10, 10 and 100 respectively).
            * 'imagenet': torchvision VGG19-BN created with
              ``pretrained=True``.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        NotImplementedError: If `dataset_name` is not one of the names above.
    """
    name = dataset_name.lower()

    if name in ('mnist', 'fmnist'):
        return nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, 3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Flatten(),
            nn.Dropout(0.5),
            # 64 channels x 2 x 2 is the feature-map size reached from a
            # 28x28 input after the three conv/pool stages above.
            nn.Linear(64 * 2 * 2, 64),
            nn.Linear(64, 10),
        )

    # The classifier head must be as wide as the dataset's label set;
    # CIFAR-100 has 100 classes, not 10.
    vgg11_num_classes = {'stl10': 10, 'cifar10': 10, 'cifar100': 100}
    if name in vgg11_num_classes:
        return torchvision.models.vgg11_bn(num_classes=vgg11_num_classes[name])

    if name == 'imagenet':
        # A hand-written VGG-F-style network (noted as the one used in the
        # original AAA paper) was sketched here earlier; the pretrained
        # VGG19-BN is what is actually used.
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` - confirm against the installed version.
        return torchvision.models.vgg19_bn(pretrained=True)

    raise NotImplementedError(
        "Only MNIST, FashionMNIST, CIFAR10, CIFAR100, STL10 and ImageNet datasets supported!"
    )

In [60]:
class UNet(nn.Module):
    """Small convolutional module whose forward pass is a 5-layer strided conv stack.

    Three sub-modules are created:
      * ``lin``    - Linear(10 -> 16)
      * ``deconv`` - five stride-2 transposed convolutions (1 -> 3 -> 3 -> 3 -> 3 -> 3
                     channels) with BatchNorm after the first and ReLU between layers
      * ``conv``   - five stride-2 convolutions (3 -> 1 -> 1 -> 1 -> 1 -> 1 channels)

    Only ``conv`` takes part in ``forward``; ``lin`` and ``deconv`` are
    constructed but currently unused.
    """

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(in_features=10, out_features=16)

        # Up-sampling stack: first layer widens 1 -> 3 channels with a 5x5
        # kernel, the remaining four keep 3 channels with 3x3 kernels.
        up_layers = [
            nn.ConvTranspose2d(1, 3, 5, stride=2),
            nn.BatchNorm2d(3),
            nn.ReLU(),
        ]
        for _ in range(3):
            up_layers.append(nn.ConvTranspose2d(3, 3, 3, stride=2))
            up_layers.append(nn.ReLU())
        up_layers.append(nn.ConvTranspose2d(3, 3, 3, stride=2))
        self.deconv = nn.Sequential(*up_layers)

        # Down-sampling stack: no activations between the convolutions.
        down_layers = [nn.Conv2d(3, 1, 3, stride=2)]
        for _ in range(3):
            down_layers.append(nn.Conv2d(1, 1, 3, stride=2))
        down_layers.append(nn.Conv2d(1, 1, 5, stride=2))
        self.conv = nn.Sequential(*down_layers)

    def forward(self, x):
        """Run `x` through the strided convolution stack and return the result."""
        # Disabled alternative path kept for reference:
        #     feat = self.lin(x).reshape(-1, 1, 4, 4)
        #     return self.deconv(feat)
        return self.conv(x)

In [61]:
# Instantiate the UNet module defined in the previous cell.
gen = UNet()

In [57]:
# Sanity check: output shape of `gen` for one all-zero 3-channel 128x128 input.
gen(torch.zeros((1, 3, 128, 128))).shape

torch.Size([1, 1, 2, 2])