In [1]:
import numpy as np
from PIL import Image
import torch

In [2]:
from dataset import CustomImageDataset
from torchvision.transforms import Resize

full_ds = CustomImageDataset(list(range(400)), load_embeddings=True)

In [3]:
import torch.nn.functional as F
from torch import nn

class MLP(torch.nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super(MLP, self).__init__()
        self.d_in = d_in
        self.linear1 = torch.nn.Linear(d_in, d_hidden)
        self.relu1 = nn.ReLU()
        self.linear2 = torch.nn.Linear(d_hidden, d_out)
        self.relu2 = nn.ReLU()

    def forward(self, X):
        X = X.view(-1, self.d_in)
        X = self.relu1(self.linear1(X))
        X = self.relu2(self.linear2(X))
        return F.softmax(X, dim=1)

In [4]:
mlp = MLP(d_in=1024, d_hidden=128, d_out=4)
print(mlp)

MLP(
  (linear1): Linear(in_features=1024, out_features=128, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=128, out_features=4, bias=True)
  (relu2): ReLU()
)


In [5]:
learning_rate = 1e-4
epochs = 20

In [6]:
from torch.utils.data import DataLoader

train_dataset, val_dataset = torch.utils.data.random_split(full_ds, [0.8, 0.2])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

In [7]:
from tqdm import tqdm


#model training
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=learning_rate) 

n_total_steps = len(train_loader)
for epoch in range(epochs):
    print(f'== epoch {epoch} ==')
    # Train the model
    mlp.train()
    train_loss = 0.0
    for i, (embedding, label) in enumerate(train_loader):  
        optimizer.zero_grad()
                
        # Forward pass
        outputs =  mlp(embedding)
        loss = criterion(outputs, label)
        
        # Backward and optimize
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= i
    
    print(f"train_loss : {train_loss}")

    mlp.eval()
    val_loss = 0.0
    for i, (embedding, label) in enumerate(val_loader):                 
        # Forward pass
        outputs =  mlp(embedding)
        loss = criterion(outputs, label)
        val_loss += loss.item()
    val_loss /= i
    print(f"val_loss : {val_loss}")

        

  emb = torch.load(emb_path)


== epoch 0 ==
1.5386125644048054
== epoch 1 ==
1.5260710716247559
== epoch 2 ==
1.5114315218395658
== epoch 3 ==
1.4944257073932223
== epoch 4 ==
1.475412819120619
== epoch 5 ==
1.4541669686635335
== epoch 6 ==
1.4307003551059299
== epoch 7 ==
1.4065180089738634
== epoch 8 ==
1.3807557158999972
== epoch 9 ==
1.3565604554282293
== epoch 10 ==
1.3312537140316434
== epoch 11 ==
1.30704132715861
== epoch 12 ==
1.2854064305623372
== epoch 13 ==
1.2642529673046536
== epoch 14 ==
1.2437139881981745
== epoch 15 ==
1.225196533732944
== epoch 16 ==
1.2080800930658977
== epoch 17 ==
1.1919839514626398
== epoch 18 ==
1.1764198541641235
== epoch 19 ==
1.1628399822447035
