In [1]:
from PIL import Image
from transformers import BertTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
from split import split_dataset
from songs_dataset import SongsDataset
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import torch.nn.functional as F

In [3]:
# Splitting the data
all_songs_path = 'data/songs/all'
split_dataset(all_songs_path)

# Loading image_processor to be used in datasets
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [4]:
# Datasets
train_dataset = SongsDataset('data/songs/train', image_processor)
validation_dataset = SongsDataset('data/songs/validation', image_processor)
test_dataset = SongsDataset('data/songs/test', image_processor)

In [5]:
# Dataloaders
BATCH_SIZE = 32 
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# Pretrained
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encod

In [8]:
# Define your optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    
    for batch in train_dataloader:
        optimizer.zero_grad()

        images = batch['image']
        lyrics = batch['lyrics']
        
        # Forward pass
        outputs = model(pixel_values=images, labels=lyrics)
        logits = outputs.logits
        
        # Calculate the loss
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), lyrics.view(-1))
        
        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {average_loss:.4f}")

ValueError: You have to specify pixel_values