In [20]:
import torch
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [21]:
from torchvision import datasets, models, transforms

data_transforms_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

data_transforms_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


In [22]:
batch_size = 32

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=True, download=True,
                    transform=data_transforms_train),
    batch_size=batch_size, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, transform=data_transforms_test),
    batch_size=1
)

Files already downloaded and verified


# ResNet as feature extractor

In [27]:
# !pip install torch-summary
from torchsummary import summary

full_model_resnet = models.resnet18(pretrained=True)
# summary(model, (3, 224, 224))
print(full_model_resnet)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [49]:
import torch.nn as nn

model_resnet = torch.nn.Sequential(*(list(full_model_resnet.children())[:-2]))

print(model_resnet[-1])
print("LAST: ", model_resnet[-1][-1].conv2)


Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1

In [None]:
for param in model_resnet.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
# num_ftrs = model_resnet.fc.in_features
# model_resnet.fc = nn.Linear(num_ftrs, 100)

model_resnet = model_resnet.to(device)

In [15]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001)

512

In [52]:
import torch.nn.functional as F

class ConvAE(nn.Module):
    def __init__(self):
        super(ConvAE, self).__init__()

        # For MSE Loss
        self.decoder = nn.Sequential(             
            nn.ConvTranspose2d(512, 256, kernel_size = 3),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 3, kernel_size = 3),
            nn.ReLU(True)
        )

    def forward(self, x):
        x = self.decoder(x)
        return x
    



In [53]:
model_ae = ConvAE().to(device)
model_ae 

ConvAE(
  (decoder): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(256, 3, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
)

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model_ae.parameters(), lr=0.001)

n_epochs = 10

for epoch in range(n_epochs):
    train_loss = 0.0

    for images, targets in train_loader:
        optimizer.zero_grad()

        x = model_resnet.train()(images)
        predictions = model_ae.train()(images)
    
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*images.size(0)

    print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / len(train_loader)}')

