# In [2]:  (Jupyter cell marker left over from the notebook export)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import re
from io import StringIO
import os
from PIL import Image
import csv
from transformers import ViTModel, ViTFeatureExtractor, BertTokenizer, BertModel
from sklearn.metrics import f1_score, accuracy_score, recall_score
from tqdm import tqdm

# Dataset locations, relative to the working directory.
IMAGE_PATH = './dataset/data/'
TRAINING_DATA_PATH = './dataset/train.csv'
TEST_DATA_PATH = './dataset/test.csv'
# Width of the multi-hot label vector; label ids from the CSV are treated as 1-based.
LABEL_COUNT = 19
BATCH_SIZE = 128
# NOTE(review): not referenced anywhere in the visible code — presumably the 224x224
# input size of the ViT checkpoint; confirm or remove.
IMAGE_SHAPE = 224
# Learning rate passed to AdamW.
LR = 0.0001
# Number of training epochs; also used as T_max of the cosine LR schedule.
EPOCHS = 1
# Weight decay passed to AdamW.
WEIGHT_DECAY = 0.001
# Sigmoid probability above which a label counts as predicted.
THRESHOLD = 0.5

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using', device)
# Image preprocessor for ViT and tokenizer for BERT, both loaded from pretrained checkpoints.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers releases in favour
# of ViTImageProcessor — confirm against the installed version before upgrading.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print("Initialise vit and bert")

# Reading CSV Files (unchanged)
def read_csv(file_path):
    """Load a CSV whose text fields may contain unescaped double quotes.

    Every line is patched before parsing: a double quote that directly follows
    a non-comma character and is followed (after optional whitespace) by a
    character other than a newline gets a ``/`` inserted in front of it.
    pandas is then told to treat ``/`` as the escape character, so those
    quotes are read as literal text instead of field delimiters.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    # The substitution is applied per line on purpose: the trailing part of the
    # pattern must not be able to reach across a line break.
    inner_quote = re.compile(r'([^,])"(\s*[^\n])')
    with open(file_path) as handle:
        patched_text = ''.join(
            inner_quote.sub(r'\1/"\2', raw_line) for raw_line in handle
        )
    return pd.read_csv(StringIO(patched_text), escapechar="/")

# Parse the training and test tables with the quote-tolerant read_csv helper.
df_train = read_csv(TRAINING_DATA_PATH)
df_test = read_csv(TEST_DATA_PATH)

# Dataset Class (modified for ViT)
class ImageTextDataset(Dataset):
    """Dataset pairing a preprocessed image with its tokenized caption.

    Each item is read from one row of ``df`` using the columns ``ImageID``,
    ``Caption`` and, unless ``test`` is true, ``Labels``.

    Args:
        df: Table with one row per sample.
        image_path: Directory that ``ImageID`` values are joined onto.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result exposes ``'pixel_values'``.
        tokenizer: Callable invoked on the caption whose result exposes
            ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length captions are padded/truncated to.
        test: When true, items carry no label vector.
    """

    def __init__(self, df, image_path, processor, tokenizer, max_length=128, test=False):
        self.df = df
        self.image_path = image_path
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.test = test

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    @staticmethod
    def _encode_labels(raw_labels, label_count):
        """Turn a whitespace-separated list of 1-based label ids into a multi-hot vector.

        Args:
            raw_labels: Cell value from the ``Labels`` column.
            label_count: Length of the returned vector.

        Returns:
            A float32 tensor of length ``label_count`` with 1 at ``id - 1`` for
            every id in ``1..label_count``; ids outside that range are ignored.

        Raises:
            ValueError: If a token cannot be parsed as an integer.
        """
        labels = torch.zeros(label_count, dtype=torch.float32)
        # str() guards against cells that pandas parsed as a bare integer
        # (a row holding a single label id), which have no .split().
        for token in str(raw_labels).split():
            label_index = int(token)
            if 0 < label_index <= label_count:
                labels[label_index - 1] = 1  # ids are 1-based, tensor is 0-based
        return labels

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask)`` plus ``labels`` unless in test mode."""
        row = self.df.iloc[idx]
        img_name = os.path.join(self.image_path, row['ImageID'])
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the with-block exits.
        with Image.open(img_name) as img_file:
            image = img_file.convert('RGB')
        image = self.processor(images=image, return_tensors="pt")['pixel_values'].squeeze(0)

        caption = row['Caption']
        inputs = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) drops the leading batch axis added by return_tensors="pt".
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        if self.test:
            return image, input_ids, attention_mask

        labels = self._encode_labels(row['Labels'], LABEL_COUNT)
        return image, input_ids, attention_mask, labels


# Datasets and DataLoaders (slightly adjusted)
# Both datasets read images from IMAGE_PATH; the test dataset is built with
# test=True, so its items carry no label vector.
train_dataset = ImageTextDataset(df_train, IMAGE_PATH, feature_extractor, tokenizer)
test_dataset = ImageTextDataset(df_test, IMAGE_PATH, feature_extractor, tokenizer, test=True)

# Splitting Dataset (unchanged)
# 80/20 train/validation split of the labelled data.
# NOTE(review): random_split is called without a seeded generator, so the split
# presumably differs between runs unless a global seed is set elsewhere — confirm.
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# Only the training loader shuffles. The test loader must keep dataset order
# because predictions are later written out row by row against df_test.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Small model you can try: ViT and BERT features feeding one linear classifier
class CombinedModel(nn.Module):
    """Classifier over concatenated ViT image features and BERT text features.

    Args:
        num_classes: Number of output logits produced per sample.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # The linear head consumes both pooled vectors side by side.
        fused_width = self.vit.config.hidden_size + self.bert.config.hidden_size
        self.classifier = nn.Linear(fused_width, num_classes)

    def forward(self, images, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits for a batch of image/caption pairs."""
        vision_out = self.vit(pixel_values=images)
        text_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        fused = torch.cat(
            [vision_out.pooler_output, text_out.pooler_output],
            dim=1,
        )
        return self.classifier(fused)


# Build the model and move it to the selected device.
model = CombinedModel(num_classes=LABEL_COUNT).to(device)
# NOTE(review): the state dict is written before any training has run, so
# './model' holds the initial weights — presumably a file-size check; confirm
# whether the trained weights should be saved after the training loop instead.
torch.save(model.state_dict(), './model')

# AdamW over all model parameters (the encoders are optimised along with the head).
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Binary cross-entropy on raw logits, applied independently per label.
criterion = nn.BCEWithLogitsLoss()
# Cosine learning-rate schedule; train() steps it once per call.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

def train(model, train_loader, optimizer, criterion, scheduler):
    """Run one pass over ``train_loader`` and return the mean per-batch loss.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        train_loader: Iterable of ``(images, input_ids, attention_mask, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scheduler: LR scheduler stepped once after the last batch.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()
    bar = tqdm(total=len(train_loader), desc='Training', leave=False)
    batch_losses = []
    for batch in train_loader:
        # Move every tensor of the batch to the training device in one go.
        images, input_ids, attention_mask, labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        logits = model(images, input_ids, attention_mask)
        batch_loss = criterion(logits, labels.float())
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
        bar.update()

    bar.close()
    # The schedule advances per epoch, not per batch.
    scheduler.step()

    return sum(batch_losses) / len(train_loader)

def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        val_loader: Iterable of ``(images, input_ids, attention_mask, labels)``
            batches that also exposes ``.dataset``.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(loss, f1, accuracy, recall)``: the loss is weighted by batch size and
        divided by ``len(val_loader.dataset)``; F1 and recall are micro-averaged;
        all three metrics use predictions thresholded at ``THRESHOLD``.
    """
    model.eval()
    loss_sum = 0.0
    prob_rows = []
    target_rows = []
    with torch.no_grad():
        for batch in val_loader:
            images, input_ids, attention_mask, labels = (t.to(device) for t in batch)

            logits = model(images, input_ids, attention_mask)
            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_sum += criterion(logits, labels).item() * images.size(0)
            prob_rows.extend(torch.sigmoid(logits).cpu().numpy())
            target_rows.extend(labels.cpu().numpy())

    targets = np.array(target_rows)
    predicted = (np.array(prob_rows) > THRESHOLD).astype(int)

    metrics = (
        f1_score(targets, predicted, average='micro'),
        accuracy_score(targets, predicted),
        recall_score(targets, predicted, average='micro'),
    )
    return (loss_sum / len(val_loader.dataset), *metrics)

print("Start training")
# Main training loop: one training pass followed by one validation pass per epoch.
for epoch in range(EPOCHS):
    print(f'Epoch {epoch+1}/{EPOCHS}')
    train_loss = train(model, train_loader, optimizer, criterion, scheduler)
    val_loss, val_f1, val_acc, val_recall = validate(model, val_loader, criterion)
    # Adjacent f-string literals are used instead of a backslash continuation
    # inside the string, which would embed the next line's indentation in the
    # message. All metrics share the same 4-decimal format.
    tqdm.write(
        f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, '
        f'Validation Loss: {val_loss:.4f}, F1 Score: {val_f1:.4f}, '
        f'Acc Score: {val_acc:.4f}, Recall Score: {val_recall:.4f}'
    )

print("End training")

# Testing Function
def predict(model, test_loader):
    """Predict label strings for every sample yielded by ``test_loader``.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        test_loader: Iterable of ``(images, input_ids, attention_mask)`` batches.

    Returns:
        One string per sample, in loader order, holding the space-separated
        1-based indices of all labels whose sigmoid probability exceeds
        ``THRESHOLD``; the string is empty when no label passes.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in test_loader:
            images, input_ids, attention_mask = (t.to(device) for t in batch)

            logits = model(images, input_ids, attention_mask)
            probs = torch.sigmoid(logits).cpu().detach().numpy()
            hits = (probs > THRESHOLD).astype(int)
            for sample_hits in hits:
                # Column positions are 0-based; exported label ids are 1-based.
                chosen = [str(pos + 1) for pos, flag in enumerate(sample_hits) if flag == 1]
                predictions.append(' '.join(chosen))
    return predictions

print("start prediction")
# Output Predictions
predictions = predict(model, test_loader)

# Export Predictions
output_file = 'predictions.csv'
image_ids = df_test['ImageID']
# Check the pairing before opening the file: a length mismatch would otherwise
# either fail midway and leave a truncated CSV or silently ignore predictions.
if len(predictions) != len(image_ids):
    raise ValueError(
        f'Got {len(predictions)} predictions for {len(image_ids)} test rows.'
    )
with open(output_file, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['ImageID', 'Labels'])
    # Rows are written in df_test order, which matches the unshuffled test loader.
    writer.writerows(zip(image_ids, predictions))

print(f'Predictions exported to {output_file}.')