In [110]:
from pathlib import Path
import pandas as pd
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, ViTModel, ViTImageProcessor, Trainer, TrainingArguments
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

In [111]:
class Config:
    """Hyperparameters and pretrained checkpoint names shared by the notebook."""

    # Hugging Face checkpoints for the text and image encoders.
    bert_model_name = 'neuralmind/bert-base-portuguese-cased'
    vit_model_name = 'google/vit-base-patch16-224'

    # Output width of the final linear layer in MultimodalModel.
    # The notebook trains with nn.MSELoss against a single label column, so
    # the head must emit exactly one value per sample; a width of 5 cannot be
    # matched against a (batch,) target.
    num_classes = 1

    # Optimisation settings used by the optimizer, DataLoader and epoch loop.
    learning_rate = 2e-5
    batch_size = 16
    num_epochs = 10

config = Config()

In [113]:
class MultimodalDataset(Dataset):
    """Text + image dataset backed by a CSV file.

    The CSV is read with pandas and must provide a ``text`` column, an
    ``image_path`` column (paths are joined onto the directory that contains
    the CSV) and the column named by ``label``.

    Indexing always returns a *batched* dict of tensors: a single integer
    index produces a batch of size one, a list of indices produces one padded
    batch. ``__getitems__`` is an alias of ``__getitem__`` so that a whole
    list of indices is tokenised/processed in one call.
    """

    def __init__(self, csv_file, tokenizer: BertTokenizer, processor: ViTImageProcessor, label: str):
        """Load the CSV and remember the preprocessors.

        Args:
            csv_file: Path to the CSV file describing the samples.
            tokenizer: Tokenizer applied to the ``text`` column.
            processor: Image processor applied to the loaded images.
            label: Name of the CSV column used as the target.
        """
        self.data = pd.read_csv(csv_file)
        # Image paths in the CSV are resolved relative to the CSV's directory.
        self.root = Path(csv_file).parent.as_posix()
        self.tokenizer = tokenizer
        self.processor = processor
        self.label = label

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def process_text(self, text: str | list[str]):
        """Tokenise ``text`` and return ``(input_ids, attention_mask)`` tensors.

        Sequences are padded to the longest item of the call and truncated to
        at most 512 tokens.
        """
        text_encoding = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        return text_encoding['input_ids'], text_encoding['attention_mask']

    def process_image(self, image_path: str | list[str]):
        """Load one or more images as RGB and return the processor's ``pixel_values``."""
        if isinstance(image_path, str):
            image_path = [image_path]
        images = []
        for path in image_path:
            # convert() returns an independent copy, so the underlying file
            # can be closed deterministically as soon as it has been read.
            with Image.open(path) as image_file:
                images.append(image_file.convert('RGB'))
        pixel_values = self.processor(images=images, return_tensors='pt')['pixel_values']
        return pixel_values

    def __getitem__(self, idx):
        """Return a batched dict for ``idx`` (a single integer or a list of integers).

        Returns:
            dict with keys ``input_ids``, ``attention_mask``, ``pixel_values``
            and ``labels`` (float tensor with one entry per requested row).
        """
        # Wrap scalar indices (Python *and* NumPy integers) so that .iloc
        # always yields a Series; a bare scalar would make .to_list() fail.
        if isinstance(idx, int) or pd.api.types.is_integer(idx):
            idx = [idx]
        text = self.data['text'].iloc[idx].to_list()
        image_path = (self.root + '/' + self.data['image_path'].iloc[idx]).to_list()
        labels = torch.tensor(self.data[self.label].iloc[idx].to_list(), dtype=torch.float)
        input_ids, attention_mask = self.process_text(text)
        pixel_values = self.process_image(image_path)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
            'labels': labels
        }
    # Batched fetch hook: reuse the same implementation for a list of indices.
    __getitems__ = __getitem__

In [114]:
class MultimodalModel(nn.Module):
    """Frozen BERT and ViT encoders fused by concatenation into one linear head."""

    def __init__(self, config: Config):
        """Load both pretrained encoders, freeze them and build the head.

        Args:
            config: Provides ``bert_model_name``, ``vit_model_name`` and
                ``num_classes`` (output width of the linear head).
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(config.bert_model_name)
        self.vit = ViTModel.from_pretrained(config.vit_model_name)

        # Freeze every BERT and ViT parameter; only the head below is trainable.
        for encoder in (self.bert, self.vit):
            encoder.requires_grad_(False)

        # Fully connected head over the concatenated text + image features.
        fused_dim = self.bert.config.hidden_size + self.vit.config.hidden_size
        self.fc = nn.Linear(fused_dim, config.num_classes)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return the head's output for a batch of tokenised text and images."""
        text_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        image_states = self.vit(pixel_values=pixel_values).last_hidden_state

        # Keep only position 0 of each sequence (the [CLS] slot in both
        # encoders) and join the two feature vectors along the feature axis.
        fused = torch.cat((text_states[:, 0, :], image_states[:, 0, :]), dim=1)

        return self.fc(fused)

In [115]:
# Preprocessors that match the pretrained text and image checkpoints.
tokenizer = BertTokenizer.from_pretrained(config.bert_model_name)
processor = ViTImageProcessor.from_pretrained(config.vit_model_name)


def _passthrough_collate(batch):
    """Hand the batch to the training loop unchanged.

    MultimodalDataset defines ``__getitems__``, which already returns a fully
    assembled dict of batched tensors, so no further collation is needed
    (assumes a PyTorch version whose DataLoader calls ``__getitems__``).
    """
    return batch


# Dataset loading
dataset = MultimodalDataset(
    'data/MEC/mec-dataset.csv',
    tokenizer,
    processor,
    label='formal_register',
)
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=_passthrough_collate,
)

In [None]:
model = MultimodalModel(config)

# Mean squared error: the label is treated as a continuous regression target.
criterion = nn.MSELoss()

# Adam receives every parameter of the model, including the encoder weights
# that MultimodalModel marks as not requiring gradients.
optimizer = optim.Adam(params=model.parameters(), lr=config.learning_rate)

In [None]:
# NOTE(review): train() also switches the frozen encoders to training mode
# (dropout active) — confirm that is intended for feature extraction.
model.train()
for epoch in range(config.num_epochs):
    running_loss = 0.0
    # Each batch is the dict produced by MultimodalDataset.__getitems__ (the
    # DataLoader's collate function passes it through unchanged). Iterating or
    # tuple-unpacking the dict would only yield its key strings, so the
    # tensors are looked up by key.
    for batch in dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['labels']

        optimizer.zero_grad()

        # Forward pass
        outputs = model(input_ids, attention_mask, pixel_values)

        # Labels have shape (batch,); a single-output head yields (batch, 1).
        # Drop the trailing axis so MSELoss compares element-wise instead of
        # broadcasting.
        if outputs.dim() == labels.dim() + 1 and outputs.size(-1) == 1:
            outputs = outputs.squeeze(-1)
        if outputs.shape != labels.shape:
            raise ValueError(
                f"Model output shape {tuple(outputs.shape)} does not match label shape "
                f"{tuple(labels.shape)}; for scalar regression with MSELoss the head must "
                f"emit one value per sample (config.num_classes = 1)."
            )
        loss = criterion(outputs, labels)

        # Backward pass and optimisation step
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so the
        # epoch value below is a per-sample average.
        running_loss += loss.item() * input_ids.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch {epoch+1}/{config.num_epochs}, Loss: {epoch_loss:.4f}")