In [1]:
# !nvidia-smi

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset, DatasetDict
from einops import rearrange, einsum

import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.models as models
from torchvision.models import resnet18
from torchvision import transforms

In [3]:
# internal_model = models.resnet18
# internal_weights = models.ResNet18_Weights.IMAGENET1K_V1

internal_model = models.resnet34
internal_weights = models.ResNet34_Weights.IMAGENET1K_V1

# internal_model = models.resnet50
# internal_weights = models.ResNet50_Weights.IMAGENET1K_V1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load images
x_train = np.load('x_train.npy')
x_test = np.load('x_test.npy')

y_train = np.load('y_train.npy')
y_test = np.load('y_test.npy')

## View some images
# plt.imshow(x_train[2,:,:,: ] )
# plt.axis('off')
# plt.show()

# # convert to torch
# x_train = torch.from_numpy(x_train)
# x_test = torch.from_numpy(x_test)

# y_train = torch.from_numpy(y_train)
# y_test = torch.from_numpy(y_test)

print('X_train shape:\t' ,x_train.shape)
print('Y_train shape\t' ,y_train.shape)

print('X_test shape\t' ,x_test.shape)
print('Y_test shape\t' ,y_test.shape)

# train_cut = 300
# x_train = x_train[:train_cut]
# y_train = y_train[:train_cut]

# test_cut = 100
# x_test = x_test[:test_cut]
# y_test = y_test[:test_cut]

# print('X_train shape:\t' ,x_train.shape)
# print('X_train dtype:\t' ,x_train.dtype)
# print('X_train type:\t' ,type(x_train))


X_train shape:	 (791, 250, 250, 3)
Y_train shape	 (791,)
X_test shape	 (784, 250, 250, 3)
Y_test shape	 (784,)


In [5]:
transform_fn = internal_weights.transforms()
convert_to_tensor = transforms.ToTensor()

print('Transforms:\n', transform_fn)

def process_image(image):
    image = convert_to_tensor(image)
    # image = rearrange(image, 'h w c -> c h w')
    image = transform_fn(image)
    return image

# transform images
x_train_transformed = list(map(process_image, x_train))
x_test_transformed = list(map(process_image, x_test))

# stack images
x_train_tensor = torch.stack(x_train_transformed) #.to(device)
x_test_tensor = torch.stack(x_test_transformed) #.to(device)

# convert labels to tensor
y_train_tensor = torch.tensor(y_train) #.to(device)
y_test_tensor = torch.tensor(y_test) #.to(device)

y_train_tensor = y_train_tensor - 1
y_test_tensor = y_test_tensor - 1

# add dimension to labels
# y_train_tensor = y_train_tensor.unsqueeze(1)
# y_test_tensor = y_test_tensor.unsqueeze(1)

# TensorDataset
train_data = TensorDataset(x_train_tensor, y_train_tensor)
test_data = TensorDataset(x_test_tensor, y_test_tensor)

# DataLoader
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True)


Transforms:
 ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [6]:
sample, label = next(iter(train_loader))
print('Sample shape:', sample.shape)
print('Label shape:', label.shape)

Sample shape: torch.Size([32, 3, 224, 224])
Label shape: torch.Size([32])


### Model

In [7]:
class BiLinearModel(nn.Module):
    def __init__(self, num_classes):
        super(BiLinearModel, self).__init__()
        
        self.cnn1 = internal_model(weights=internal_weights)
        self.cnn2 = internal_model(weights=internal_weights)
        

        self.cnn1 = nn.Sequential(*list(self.cnn1.children())[:-2])
        self.cnn2 = nn.Sequential(*list(self.cnn2.children())[:-2])

        self.feature_size = internal_model(weights=internal_weights).fc.in_features
        # print('Feature size:', self.feature_size)

        # Define bilinear pooling
        self.fc = nn.Linear(self.feature_size**2, num_classes) 
        # nn.Sequential(
        #     nn.Linear(512*512, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, num_classes)
        # )
    
    def forward(self, x):
        x1 = self.cnn1(x)
        x2 = self.cnn2(x)
        
        # bilinear pooling with einops
        x1 = rearrange(x1, 'b k h w -> b k (h w)')
        x2 = rearrange(x2, 'b k h w -> b k (h w)')
        x = einsum(x1, x2, 'b i j, b k j -> b i k')
        # print('X shape:', x.shape)
        x = rearrange(x, 'b i j -> b (i j)')
        # print('X shape:', x.shape)

        x = self.fc(x)
        return x

model = BiLinearModel(num_classes=20)
# model = model.to(device)

in_tensor = torch.randn(1, 3, 224, 224)#.to(device)
model(in_tensor).shape


torch.Size([1, 20])

In [8]:
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm 

# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for sample in train_loader:
                image, label = sample
                # image, label = image.to(device), label.to(device)
                optimizer.zero_grad()
                outputs = model(image)
                loss = criterion(outputs, label)
                loss.backward()
                optimizer.step()
                # update progress bar
                running_loss += loss.item()*image.size(0)

                # accuracy
                _, preds = torch.max(outputs, 1)
                corrects = torch.sum(preds == label.data)
                accuracy = corrects.double() / image.size(0)
                
                pbar.set_postfix(loss=running_loss/len(train_loader.dataset), accuracy=accuracy.item())
                pbar.update(1)
        
        scheduler.step()
        
        # validation
        # model.eval()
        # with torch.no_grad():
        #     with tqdm(total=len(test_loader), desc=f"Validation") as pbar:
        #         val_loss = 0.0
        #         for sample in test_loader:
        #             image, label = sample
        #             image, label = image.to(device), label.to(device)
        #             outputs = model(image)
        #             loss = criterion(outputs, label)
        #             val_loss += loss.item()*image.size(0)
        #             # accuracy
        #             _, preds = torch.max(outputs, 1)
        #             corrects = torch.sum(preds == label.data)
        #             accuracy = corrects.double() / image.size(0)
        #             pbar.set_postfix(loss=val_loss/len(test_loader.dataset), accuracy=accuracy.item())
        #             pbar.update(1)
    return model

# Freeze the weights of the pre-trained models
for param in model.cnn1.parameters():
    param.requires_grad = False
for param in model.cnn2.parameters():
    param.requires_grad = False

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=1e-4, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Train the model
model = train_model(model, criterion, optimizer, scheduler, num_epochs=10)

# Unfreeze the weights and train again
for param in model.cnn1.parameters():
    param.requires_grad = True
for param in model.cnn2.parameters():
    param.requires_grad = True

optimizer = optim.SGD(model.parameters(), lr=1e-6, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train the model again
model = train_model(model, criterion, optimizer, scheduler, num_epochs=15)

# evaluate the model
model.eval()
corrects = 0
total = 0
with torch.no_grad():
    for sample in test_loader:
        image, label = sample
        image, label = image.to(device), label.to(device)
        outputs = model(image)
        _, preds = torch.max(outputs, 1)
        corrects += torch.sum(preds == label.data)
        total += image.size(0)
print(f"Accuracy: {corrects.double()/total}")


Epoch 1/10:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 1/10:   4%|▍         | 1/24 [00:04<01:34,  4.10s/it, accuracy=0.0312, loss=3.84]

Preds: tensor(10)
Label: tensor(9)


Epoch 1/10:   8%|▊         | 2/24 [00:08<01:35,  4.34s/it, accuracy=0.156, loss=188]  

Preds: tensor(4)
Label: tensor(18)


Epoch 1/10:  12%|█▎        | 3/24 [00:12<01:27,  4.18s/it, accuracy=0.0625, loss=786]

Preds: tensor(6)
Label: tensor(6)


Epoch 1/10:  17%|█▋        | 4/24 [00:16<01:22,  4.12s/it, accuracy=0.0312, loss=1.9e+3]

Preds: tensor(13)
Label: tensor(18)


Epoch 1/10:  21%|██        | 5/24 [00:20<01:18,  4.11s/it, accuracy=0.125, loss=3.48e+3]

Preds: tensor(16)
Label: tensor(15)


Epoch 1/10:  25%|██▌       | 6/24 [00:25<01:15,  4.19s/it, accuracy=0.125, loss=4.92e+3]

Preds: tensor(19)
Label: tensor(12)


Epoch 1/10:  29%|██▉       | 7/24 [00:29<01:10,  4.17s/it, accuracy=0.125, loss=6.86e+3]

Preds: tensor(9)
Label: tensor(2)


Epoch 1/10:  33%|███▎      | 8/24 [00:33<01:08,  4.30s/it, accuracy=0.344, loss=8.31e+3]

Preds: tensor(18)
Label: tensor(18)


Epoch 1/10:  38%|███▊      | 9/24 [00:37<01:03,  4.26s/it, accuracy=0.219, loss=1.01e+4]

Preds: tensor(14)
Label: tensor(11)


Epoch 1/10:  42%|████▏     | 10/24 [00:42<00:58,  4.20s/it, accuracy=0.0938, loss=1.29e+4]

Preds: tensor(5)
Label: tensor(16)


Epoch 1/10:  46%|████▌     | 11/24 [00:46<00:53,  4.14s/it, accuracy=0.0938, loss=1.65e+4]

Preds: tensor(8)
Label: tensor(10)


Epoch 1/10:  50%|█████     | 12/24 [00:50<00:49,  4.11s/it, accuracy=0.188, loss=1.92e+4] 

Preds: tensor(12)
Label: tensor(6)


Epoch 1/10:  54%|█████▍    | 13/24 [00:54<00:45,  4.10s/it, accuracy=0.219, loss=2.12e+4]

Preds: tensor(2)
Label: tensor(6)


Epoch 1/10:  58%|█████▊    | 14/24 [00:58<00:40,  4.06s/it, accuracy=0.219, loss=2.47e+4]

Preds: tensor(11)
Label: tensor(8)


Epoch 1/10:  62%|██████▎   | 15/24 [01:02<00:36,  4.05s/it, accuracy=0.188, loss=2.7e+4] 

Preds: tensor(15)
Label: tensor(15)


Epoch 1/10:  67%|██████▋   | 16/24 [01:06<00:32,  4.10s/it, accuracy=0.219, loss=2.89e+4]

Preds: tensor(15)
Label: tensor(7)


Epoch 1/10:  71%|███████   | 17/24 [01:10<00:28,  4.11s/it, accuracy=0.438, loss=3.04e+4]

Preds: tensor(0)
Label: tensor(0)


Epoch 1/10:  75%|███████▌  | 18/24 [01:14<00:24,  4.10s/it, accuracy=0.0625, loss=3.26e+4]

Preds: tensor(16)
Label: tensor(17)


Epoch 1/10:  79%|███████▉  | 19/24 [01:18<00:20,  4.12s/it, accuracy=0.312, loss=3.39e+4] 

Preds: tensor(1)
Label: tensor(17)


Epoch 1/10:  83%|████████▎ | 20/24 [01:22<00:16,  4.13s/it, accuracy=0.281, loss=3.57e+4]

Preds: tensor(3)
Label: tensor(3)


Epoch 1/10:  88%|████████▊ | 21/24 [01:26<00:12,  4.10s/it, accuracy=0.188, loss=3.76e+4]

Preds: tensor(9)
Label: tensor(6)


Epoch 1/10:  92%|█████████▏| 22/24 [01:31<00:08,  4.14s/it, accuracy=0.312, loss=3.88e+4]

Preds: tensor(8)
Label: tensor(18)


Epoch 1/10:  96%|█████████▌| 23/24 [01:35<00:04,  4.16s/it, accuracy=0.438, loss=3.98e+4]

Preds: tensor(17)
Label: tensor(17)


Epoch 1/10: 100%|██████████| 24/24 [01:39<00:00,  4.14s/it, accuracy=0.188, loss=4.13e+4]


Preds: tensor(17)
Label: tensor(15)


Epoch 2/10:   4%|▍         | 1/24 [00:03<01:31,  3.96s/it, accuracy=0.406, loss=730]

Preds: tensor(12)
Label: tensor(12)


Epoch 2/10:   8%|▊         | 2/24 [00:07<01:27,  3.99s/it, accuracy=0.531, loss=1.17e+3]

Preds: tensor(8)
Label: tensor(8)


Epoch 2/10:  12%|█▎        | 3/24 [00:12<01:24,  4.03s/it, accuracy=0.438, loss=1.46e+3]

Preds: tensor(1)
Label: tensor(3)


Epoch 2/10:  17%|█▋        | 4/24 [00:16<01:20,  4.03s/it, accuracy=0.5, loss=1.91e+3]  

Preds: tensor(15)
Label: tensor(16)


Epoch 2/10:  21%|██        | 5/24 [00:20<01:16,  4.05s/it, accuracy=0.531, loss=2.35e+3]

Preds: tensor(18)
Label: tensor(18)


Epoch 2/10:  25%|██▌       | 6/24 [00:24<01:13,  4.07s/it, accuracy=0.5, loss=2.98e+3]  

Preds: tensor(9)
Label: tensor(7)


Epoch 2/10:  29%|██▉       | 7/24 [00:28<01:08,  4.04s/it, accuracy=0.594, loss=3.31e+3]

Preds: tensor(3)
Label: tensor(3)


Epoch 2/10:  33%|███▎      | 8/24 [00:32<01:04,  4.03s/it, accuracy=0.5, loss=3.84e+3]  

Preds: tensor(19)
Label: tensor(19)


Epoch 2/10:  38%|███▊      | 9/24 [00:36<01:00,  4.04s/it, accuracy=0.562, loss=4.22e+3]

Preds: tensor(2)
Label: tensor(13)


Epoch 2/10:  42%|████▏     | 10/24 [00:40<00:56,  4.02s/it, accuracy=0.688, loss=4.43e+3]

Preds: tensor(19)
Label: tensor(11)


Epoch 2/10:  46%|████▌     | 11/24 [00:44<00:52,  4.07s/it, accuracy=0.656, loss=4.66e+3]

Preds: tensor(9)
Label: tensor(9)


Epoch 2/10:  50%|█████     | 12/24 [00:49<00:52,  4.35s/it, accuracy=0.531, loss=5.08e+3]

Preds: tensor(19)
Label: tensor(19)


Epoch 2/10:  54%|█████▍    | 13/24 [00:53<00:47,  4.31s/it, accuracy=0.406, loss=5.62e+3]

Preds: tensor(7)
Label: tensor(6)


Epoch 2/10:  58%|█████▊    | 14/24 [00:57<00:42,  4.27s/it, accuracy=0.594, loss=6.04e+3]

Preds: tensor(7)
Label: tensor(16)


Epoch 2/10:  62%|██████▎   | 15/24 [01:02<00:38,  4.25s/it, accuracy=0.688, loss=6.53e+3]

Preds: tensor(9)
Label: tensor(9)


Epoch 2/10:  67%|██████▋   | 16/24 [01:06<00:33,  4.24s/it, accuracy=0.594, loss=6.89e+3]

Preds: tensor(12)
Label: tensor(13)


Epoch 2/10:  71%|███████   | 17/24 [01:10<00:29,  4.24s/it, accuracy=0.594, loss=7.28e+3]

Preds: tensor(19)
Label: tensor(19)


Epoch 2/10:  75%|███████▌  | 18/24 [01:14<00:25,  4.27s/it, accuracy=0.5, loss=7.59e+3]  

Preds: tensor(16)
Label: tensor(17)


In [None]:
# bilinear pooling with einops
# x1 = rearrange(x1, 'b k h w -> b k (h w)')
# x2 = rearrange(x2, 'b k h w -> b k (h w)')
# x = einsum('b i j, b k j -> b i k', x1, x2)
