In [None]:
%matplotlib inline
import json
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# import pandas_path # Path style access for pandas
from tqdm import tqdm

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
# import fasttext

In [None]:
class LanguageAndVisionConcat(torch.nn.Module):
    """Two-branch classifier that fuses features by concatenation.

    ``forward`` runs ``dialog_module`` on the ``text`` input and
    ``transaction_module`` on the ``table`` input, applies ReLU to each,
    concatenates the results along dim 1, and passes them through a linear
    fusion layer (ReLU + dropout) followed by a linear classification head.

    Args:
        num_classes: Output width of the classification head.
        loss_fn: Callable invoked as ``loss_fn(pred, label)`` when a label is
            supplied to ``forward``.
        dialog_module: Feature extractor applied to the ``text`` input.
        transaction_module: Feature extractor applied to the ``table`` input.
        language_feature_dim: Width (dim 1) of ``dialog_module``'s output.
        vision_feature_dim: Width (dim 1) of ``transaction_module``'s output.
        fusion_output_size: Output width of the fusion layer.
        dropout_p: Dropout probability applied after the fusion layer.
    """

    def __init__(self,
                 num_classes,
                 loss_fn,
                 dialog_module,
                 transaction_module,
                 language_feature_dim,
                 vision_feature_dim,
                 fusion_output_size,
                 dropout_p,
                 ):
        super().__init__()
        self.dialog_module = dialog_module
        self.transaction_module = transaction_module

        # The fusion layer consumes the concatenation of both feature vectors,
        # so its input width is the sum of the two feature dimensions (not the
        # sum of the modules themselves).
        self.fusion = torch.nn.Linear(
            in_features=language_feature_dim + vision_feature_dim,
            out_features=fusion_output_size,
        )
        self.fc = torch.nn.Linear(in_features=fusion_output_size,
                                  out_features=num_classes)
        self.loss_fn = loss_fn
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, text, table, label=None):
        """Classify one batch.

        Args:
            text: Input forwarded to ``dialog_module``.
            table: Input forwarded to ``transaction_module``.
            label: Optional targets; when given, the loss is computed.

        Returns:
            Tuple ``(pred, loss, logits)``: softmax probabilities over dim 1,
            the loss (``None`` when ``label`` is ``None``), and the raw
            classification-head outputs.
        """
        text_features = torch.nn.functional.relu(self.dialog_module(text))
        table_features = torch.nn.functional.relu(self.transaction_module(table))
        # Both feature tensors must agree on every dimension except dim 1.
        combined = torch.cat([text_features, table_features], dim=1)
        fused = self.dropout(torch.nn.functional.relu(self.fusion(combined)))

        logits = self.fc(fused)
        # Explicit class dimension; calling softmax without `dim` is deprecated.
        pred = torch.nn.functional.softmax(logits, dim=1)
        # NOTE(review): loss_fn receives the softmax probabilities, not the raw
        # logits. Losses that apply their own log-softmax (for example
        # torch.nn.CrossEntropyLoss) should instead be computed by the caller
        # from the returned `logits`.
        loss = self.loss_fn(pred, label) if label is not None else None

        return (pred, loss, logits)


In [None]:
import torch.optim as optim
import torch.nn as nn

# Loss used by the training loop on the model's raw logits.
criterion = nn.CrossEntropyLoss()

# The constructor takes eight required arguments, so they are passed by
# keyword to keep each value in the right slot.
# NOTE(review): `image_module` is assumed to be the branch that consumes the
# second (`table`) input -- confirm. `language_feature_dim`,
# `vision_feature_dim`, `fusion_output_size` and `dropout_p` must be defined
# alongside `num_classes`, `dialog_module` and `image_module`; the two
# feature dims have to equal the output widths of the respective modules.
model = LanguageAndVisionConcat(
    num_classes=num_classes,
    loss_fn=criterion,
    dialog_module=dialog_module,
    transaction_module=image_module,
    language_feature_dim=language_feature_dim,
    vision_feature_dim=vision_feature_dim,
    fusion_output_size=fusion_output_size,
    dropout_p=dropout_p,
)
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [None]:
# Set up the data loaders: shuffle only the training data so validation
# batches are visited in a stable order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Train the model.
num_epochs = 10
for epoch in range(num_epochs):
    # Make sure dropout is active, even if a previous cell switched the
    # model to eval mode.
    model.train()
    for i, batch in enumerate(train_loader):
        images, texts, labels = batch
        optimizer.zero_grad()
        # forward() is declared as (text, table, label) and returns
        # (pred, loss, logits); pass by keyword so the inputs cannot be
        # swapped, and keep only the raw logits for the criterion.
        # NOTE(review): `images` is assumed to be the `table` input -- confirm.
        _, _, logits = model(text=texts, table=images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Epoch {epoch+1}, batch {i+1}: loss {loss.item():.4f}")

In [None]:

# Evaluate on the validation set.
total_correct = 0
total_samples = 0
# Switch dropout off for evaluation; the training cell calls model.train()
# again at the start of every epoch.
model.eval()
with torch.no_grad():
    for batch in val_loader:
        images, texts, labels = batch
        # forward() returns (pred, loss, logits); only the logits are needed.
        # NOTE(review): `images` is assumed to be the `table` input -- confirm.
        _, _, logits = model(text=texts, table=images)
        predictions = torch.argmax(logits, dim=1)
        total_correct += (predictions == labels).sum().item()
        total_samples += len(labels)
# Report NaN instead of raising ZeroDivisionError on an empty validation set.
accuracy = total_correct / total_samples if total_samples else float("nan")
# `epoch` is the loop variable left over from the training cell, so this
# reports the last epoch that was trained.
print(f"Epoch {epoch+1} validation accuracy: {accuracy:.4f}")