In [1]:
# !nvidia-smi

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset, DatasetDict
from einops import rearrange, einsum

import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.models as models
from torchvision.models import resnet18
from torchvision import transforms

In [2]:
# internal_model = models.resnet18
# internal_weights = models.ResNet18_Weights.IMAGENET1K_V1

internal_model = models.resnet34
internal_weights = models.ResNet34_Weights.IMAGENET1K_V1

# internal_model = models.resnet50
# internal_weights = models.ResNet50_Weights.IMAGENET1K_V1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Load images
x_train = np.load('x_train.npy')
x_test = np.load('x_test.npy')

y_train = np.load('y_train.npy')
y_test = np.load('y_test.npy')

## View some images
# plt.imshow(x_train[2,:,:,: ] )
# plt.axis('off')
# plt.show()

# # convert to torch
# x_train = torch.from_numpy(x_train)
# x_test = torch.from_numpy(x_test)

# y_train = torch.from_numpy(y_train)
# y_test = torch.from_numpy(y_test)

print('X_train shape:\t' ,x_train.shape)
print('Y_train shape\t' ,y_train.shape)

print('X_test shape\t' ,x_test.shape)
print('Y_test shape\t' ,y_test.shape)

train_cut = 300
x_train = x_train[:train_cut]
y_train = y_train[:train_cut]

test_cut = 100
x_test = x_test[:test_cut]
y_test = y_test[:test_cut]

# print('X_train shape:\t' ,x_train.shape)
# print('X_train dtype:\t' ,x_train.dtype)
# print('X_train type:\t' ,type(x_train))


X_train shape:	 (791, 250, 250, 3)
Y_train shape	 (791,)
X_test shape	 (784, 250, 250, 3)
Y_test shape	 (784,)


In [4]:
transform_fn = internal_weights.transforms()
convert_to_tensor = transforms.ToTensor()

print('Transforms:\n', transform_fn)

def process_image(image):
    image = convert_to_tensor(image)
    # image = rearrange(image, 'h w c -> c h w')
    image = transform_fn(image)
    return image

# transform images
x_train_transformed = list(map(process_image, x_train))
x_test_transformed = list(map(process_image, x_test))

# stack images
x_train_tensor = torch.stack(x_train_transformed) #.to(device)
x_test_tensor = torch.stack(x_test_transformed) #.to(device)

# convert labels to tensor
y_train_tensor = torch.tensor(y_train) #.to(device)
y_test_tensor = torch.tensor(y_test) #.to(device)

# add dimension to labels
# y_train_tensor = y_train_tensor.unsqueeze(1)
# y_test_tensor = y_test_tensor.unsqueeze(1)

# TensorDataset
train_data = TensorDataset(x_train_tensor, y_train_tensor)
test_data = TensorDataset(x_test_tensor, y_test_tensor)

# DataLoader
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


Transforms:
 ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [5]:
sample, label = next(iter(train_loader))
print('Sample shape:', sample.shape)
print('Label shape:', label.shape)

Sample shape: torch.Size([32, 3, 224, 224])
Label shape: torch.Size([32])


### Model

In [7]:
class BiLinearModel(nn.Module):
    def __init__(self, num_classes):
        super(BiLinearModel, self).__init__()
        
        self.cnn1 = internal_model(weights=internal_weights)
        self.cnn2 = internal_model(weights=internal_weights)
        

        self.cnn1 = nn.Sequential(*list(self.cnn1.children())[:-2])
        self.cnn2 = nn.Sequential(*list(self.cnn2.children())[:-2])

        self.feature_size = internal_model(weights=internal_weights).fc.in_features
        # print('Feature size:', self.feature_size)

        # Define bilinear pooling
        self.fc = nn.Linear(self.feature_size**2, num_classes) 
        # nn.Sequential(
        #     nn.Linear(512*512, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, num_classes)
        # )
    
    def forward(self, x):
        x1 = self.cnn1(x)
        x2 = self.cnn2(x)
        
        # # Bilinear pooling
        # batch_size = x1.size(0)
        # x1 = x1.view(batch_size, 512, 49)
        # x2 = x2.view(batch_size, 512, 49)
        # x = torch.bmm(x1, x2.transpose(1, 2)) / 49
        # x = x.view(batch_size, 512*512)
        # print('X1 shape:', x1.shape)
        
        # bilinear pooling with einops
        x1 = rearrange(x1, 'b k h w -> b k (h w)')
        x2 = rearrange(x2, 'b k h w -> b k (h w)')
        x = einsum(x1, x2, 'b i j, b k j -> b i k')
        # print('X shape:', x.shape)
        x = rearrange(x, 'b i j -> b (i j)')
        # print('X shape:', x.shape)

        x = self.fc(x)
        return x

model = BiLinearModel(num_classes=20).to(device)

in_tensor = torch.randn(1, 3, 224, 224).to(device)
model(in_tensor).shape


torch.Size([1, 20])

In [8]:
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm 

# Freeze the weights of the pre-trained models
for param in model.cnn1.parameters():
    param.requires_grad = False
for param in model.cnn2.parameters():
    param.requires_grad = False

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for sample in train_loader:
                image, label = sample
                image, label = image.to(device), label.to(device)
                optimizer.zero_grad()
                outputs = model(image)
                loss = criterion(outputs, label)
                loss.backward()
                optimizer.step()
                # update progress bar
                running_loss += loss.item()*image.size(0)

                # accuracy
                _, preds = torch.max(outputs, 1)
                corrects = torch.sum(preds == label.data)
                accuracy = corrects.double() / image.size(0)
                
                pbar.set_postfix(loss=running_loss/len(train_loader.dataset), accuracy=accuracy.item())
                pbar.update(1)
            
            # validation loss
            model.eval()
            with torch.no_grad():
                val_loss = 0.0
                for sample in test_loader:
                    image, label = sample
                    image, label = image.to(device), label.to(device)
                    outputs = model(image)
                    loss = criterion(outputs, label)
                    val_loss += loss.item()*image.size(0)
            scheduler.step()
            epoch_loss = running_loss / len(train_loader.dataset)
            print(f"Epoch {epoch+1}/{num_epochs} loss: {epoch_loss:.4f}")

    return model

# Train the model
model = train_model(model, criterion, optimizer, scheduler, num_epochs=10)

# Unfreeze the weights and train again
for param in model.cnn1.parameters():
    param.requires_grad = True
for param in model.cnn2.parameters():
    param.requires_grad = True

optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train the model again
model = train_model(model, criterion, optimizer, scheduler, num_epochs=15)


Epoch 1/10:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10: 100%|██████████| 10/10 [00:02<00:00,  3.90it/s, accuracy=0.417, loss=2.66e+5]


Epoch 1/10 loss: 266064.7557


Epoch 2/10: 100%|██████████| 10/10 [00:02<00:00,  4.49it/s, accuracy=0.667, loss=1.59e+5]


Epoch 2/10 loss: 159159.8666


Epoch 3/10: 100%|██████████| 10/10 [00:02<00:00,  4.49it/s, accuracy=0.917, loss=4.71e+4]


Epoch 3/10 loss: 47141.9481


Epoch 4/10:  50%|█████     | 5/10 [00:00<00:00,  5.66it/s, accuracy=1, loss=1.26e+4]    

In [None]:
# bilinear pooling with einops
# x1 = rearrange(x1, 'b k h w -> b k (h w)')
# x2 = rearrange(x2, 'b k h w -> b k (h w)')
# x = einsum('b i j, b k j -> b i k', x1, x2)
