In [1]:
import torch 
from torch import nn 
from torchvision import transforms 
from torch.utils.data import DataLoader
from models import ResformerEncoder, ResformerDecoder
from dataset import AmazonImageData
from trainer import model_trainer
import tokens
from pathlib import Path

In [None]:
# Use the GPU when PyTorch reports one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [3]:
# Hyperparameters

# Checkpoint files for the two model halves (relative paths).
ENCODER_SAVE_PATH = 'models/first_encoder.pth'
DECODER_SAVE_PATH = 'models/first_decoder.pth'

# Optimisation settings: epochs to train, samples per batch, Adam step size.
NUM_EPOCHS = 1
BATCH_SIZE = 16
LEARNING_RATE = 1e-5

In [4]:
# Token collections supplied by the project's `tokens` module.
input_tokens, output_tokens = tokens.get_tokens()

# Image preprocessing, applied in order: convert to a tensor, resize to 400x400,
# then normalise each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize(size=(400, 400)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Paths handed to the dataset: the train CSV and the image root directory.
root = 'amazon files/data/train.csv'
img_root = 'dataset/train'

data = AmazonImageData(
    root=root,
    img_root=img_root,
    input_tokens=input_tokens,
    output_tokens=output_tokens,
    max_seq_len=64,
    transform=transform,
)

# NOTE(review): shuffle=False yields batches in dataset order on every epoch —
# confirm this is intended for a training loader.
dataloader = DataLoader(dataset=data, shuffle=False, batch_size=BATCH_SIZE)

# Display the number of samples and the number of batches.
len(data), len(dataloader)

In [6]:
# Initiating model instances

# Construct the encoder first, then the decoder, moving each onto the selected device.
encoder, decoder = ResformerEncoder().to(device), ResformerDecoder().to(device)

In [None]:
def _restore_checkpoint(model, save_path, label):
    """Load weights from ``save_path`` into ``model`` when that file exists.

    Prints ``"<label>) Exists"`` after a successful load, or
    ``"<label>) Creating"`` when no checkpoint file is found, in which case
    the model keeps its freshly initialised weights.

    Args:
        model: Module whose ``load_state_dict`` receives the loaded state.
        save_path: Path of the checkpoint file to look for.
        label: Prefix printed in front of the status message.
    """
    if not Path(save_path).is_file():
        print(f"{label}) Creating")
        return

    # map_location remaps the stored tensors onto the device chosen for this
    # run, so a checkpoint written on a GPU still loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(f=save_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"{label}) Exists")


# Resume from earlier checkpoints when present: 1 = encoder, 2 = decoder.
_restore_checkpoint(encoder, ENCODER_SAVE_PATH, label=1)
_restore_checkpoint(decoder, DECODER_SAVE_PATH, label=2)

In [None]:
# Loss function, optimizer and gradscaler

loss_fn = nn.CrossEntropyLoss()

# A single Adam optimizer updates the encoder and decoder parameters together.
trainable_params = [*encoder.parameters(), *decoder.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)

# Gradient scaling only applies to CUDA mixed-precision training. Tie it to the
# selected device so the CPU fallback gets an explicitly disabled (pass-through)
# scaler instead of one that warns and disables itself at construction time.
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

In [None]:
# Start training with the objects configured above; both checkpoint paths are
# passed through to the trainer.
model_trainer(
    encoder=encoder,
    decoder=decoder,
    dataloader=dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scaler=scaler,
    epochs=NUM_EPOCHS,
    device=device,
    encoder_save_path=ENCODER_SAVE_PATH,
    decoder_save_path=DECODER_SAVE_PATH,
)