In [21]:
import torch.nn as nn
from torchvision.models.resnet import resnet50
import json
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, classification_report
from transformers import AutoModel, AutoTokenizer, get_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from tqdm import tqdm, trange
from time import perf_counter
from PIL import Image
import pandas as pd



In [22]:
# Folder that contains the training images. The RAKUTEN_IMAGE_FOLDER environment
# variable overrides it so the notebook is not tied to a single machine; the default
# is the original hard-coded location.
IMAGE_FOLDER = os.environ.get(
    "RAKUTEN_IMAGE_FOLDER",
    "/home/user/Desktop/RakutenMultimodalClassification/images/image_train/",
)


class ResNetDataset(Dataset):
    """Dataset yielding ``(text, label_id, image_tensor)`` triples from a DataFrame.

    Each row provides a text column, a class column and the two ids that make up
    the image file name ``image_<imageid>_product_<productid>.jpg`` inside
    ``IMAGE_FOLDER``.

    Args:
        df: DataFrame with one row per sample. Its index is reset so that samples
            can be addressed by position.
        label_to_id: Mapping from raw class value to integer label id.
        train: If True, apply the random-crop/flip training transform, otherwise
            the deterministic resize + centre-crop evaluation transform.
        text_field: Name of the column holding the text.
        label_field: Name of the column holding the class value.
        image_path_field: Name of the column holding the image id.
        product_id_field: Name of the column holding the product id.
    """

    def __init__(self, df, label_to_id, train=False, text_field="designation", label_field="class", image_path_field="imageid", product_id_field='productid'):
        self.df = df.reset_index(drop=True)
        self.label_to_id = label_to_id
        self.train = train
        self.text_field = text_field
        self.label_field = label_field
        self.image_path_field = image_path_field
        self.product_id_field = product_id_field

        # ResNet-50 settings: 224x224 crops normalised with the ImageNet channel
        # statistics that the torchvision pretrained weights were trained with.
        # (The previous values were CLIP's statistics.)
        self.img_size = 224
        self.mean, self.std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        self.train_transform_func = transforms.Compose([
            transforms.RandomResizedCrop(self.img_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

        self.eval_transform_func = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

    def _image_path(self, index):
        """Return the image file path for the row at position ``index``."""
        image_id = self.df.at[index, self.image_path_field]
        product_id = self.df.at[index, self.product_id_field]
        # os.path.join works whether or not IMAGE_FOLDER ends with a separator.
        return os.path.join(IMAGE_FOLDER, f"image_{image_id}_product_{product_id}.jpg")

    def __getitem__(self, index):
        """Return ``(text, label_id, image_tensor)`` for the row at ``index``.

        Raises:
            KeyError: If the row's class value is missing from ``label_to_id``.
            FileNotFoundError: If the image file does not exist.
        """
        text = str(self.df.at[index, self.text_field])
        label = self.label_to_id[self.df.at[index, self.label_field]]

        # Force three channels: grayscale / palette / CMYK / RGBA files would
        # otherwise fail in the 3-channel Normalize. The context manager closes
        # the underlying file once the pixels have been converted.
        with Image.open(self._image_path(index)) as raw_image:
            image = raw_image.convert("RGB")

        transform = self.train_transform_func if self.train else self.eval_transform_func
        return text, label, transform(image)

    def __len__(self):
        """Number of rows (samples) in the underlying DataFrame."""
        return self.df.shape[0]


In [23]:
class ResNetFeatureModel(nn.Module):
    """Pretrained ResNet-50 truncated after a named child layer.

    The children of the torchvision ResNet-50 are kept, in order, up to and
    including ``output_layer``; the forward pass runs them and flattens the
    result to ``(batch, features)``.

    Args:
        output_layer: Name of the last ResNet child to keep (e.g. ``"avgpool"``).
            ``None`` keeps every child.

    Raises:
        ValueError: If ``output_layer`` is given but matches no child name.
            Previously a misspelt name silently kept the whole network.
    """

    def __init__(self, output_layer):
        super().__init__()
        self.output_layer = output_layer
        pretrained_resnet = resnet50(pretrained=True)
        # Kept as a plain attribute for backward compatibility; the modules are
        # registered through self.net.
        self.children_list = self._layers_up_to(pretrained_resnet.named_children(), output_layer)
        self.net = nn.Sequential(*self.children_list)

    @staticmethod
    def _layers_up_to(named_children, output_layer):
        """Collect child modules in order, stopping after the one named ``output_layer``."""
        kept = []
        seen_names = []
        for name, child in named_children:
            kept.append(child)
            seen_names.append(name)
            if name == output_layer:
                return kept
        if output_layer is not None:
            raise ValueError(
                f"output_layer {output_layer!r} not found among children: {seen_names}"
            )
        return kept

    def forward(self, x):
        """Run the truncated network and flatten everything after the batch dimension."""
        x = self.net(x)
        x = torch.flatten(x, 1)
        return x




class BertResNetModel(nn.Module):
    """Late-fusion classifier: text embedding + ResNet-50 image features -> linear head.

    Args:
        num_labels: Number of output classes (width of the logits).
        text_pretrained: Hugging Face model id loaded with ``AutoModel`` and
            ``trust_remote_code=True``; the loaded model must expose an
            ``encode(texts, convert_to_tensor=True)`` method.
    """

    def __init__(self, num_labels, text_pretrained='jinaai/jina-embeddings-v3'):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained(text_pretrained, trust_remote_code=True)
        self.visual_encoder = ResNetFeatureModel(output_layer="avgpool")
        # Expected width of the flattened ResNet-50 "avgpool" output.
        self.image_hidden_size = 2048
        self.classifier = nn.Linear(self.text_encoder.config.hidden_size + self.image_hidden_size, num_labels)

    def forward(self, text, image):
        """Return class logits for a batch.

        Args:
            text: Batch of raw strings, passed straight to the text encoder's ``encode``.
            image: Batch of image tensors for the visual encoder.

        Returns:
            Tensor of logits with one row per sample and ``num_labels`` columns.
        """
        # NOTE(review): encode() comes from the model's remote code; it presumably
        # runs without gradient tracking, in which case the text encoder is not
        # fine-tuned by this model — confirm against the model card.
        text_feature = self.text_encoder.encode(text, convert_to_tensor=True)
        img_feature = self.visual_encoder(image)

        # encode() decides the device/dtype of its output on its own; align it with
        # the image features so the concatenation and the linear head cannot fail
        # on a CPU or half-precision embedding. No-op when they already match.
        text_feature = text_feature.to(device=img_feature.device, dtype=img_feature.dtype)

        features = torch.cat((text_feature, img_feature), 1)
        logits = self.classifier(features)
        return logits


In [24]:
# Width of the classification head (passed to the model as num_labels).
num_out_labels = 27

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Build the multimodal model and move its parameters onto the selected device.
resnet_model = BertResNetModel(num_labels=num_out_labels).to(device)


cuda


flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn i

In [25]:
from sklearn.model_selection import train_test_split

# Fixed seed so the 10% subsample and the train/validation split are the same on
# every Restart & Run All.
SPLIT_SEED = 42

df = pd.read_csv("train.csv")
# take only 10% of data
df_small = df.sample(frac=0.1, random_state=SPLIT_SEED)
# Stratify on the class column so both splits keep the class proportions.
df_train, df_val = train_test_split(
    df_small,
    test_size=0.2,
    stratify=df_small['class'],
    random_state=SPLIT_SEED,
)

# Build the label map from every class in the full file, not only those that landed
# in df_train: a class present only in df_val would otherwise raise KeyError in the
# dataset, and the ids would shift whenever the subsample misses a class.
label_to_id = {lab: i for i, lab in enumerate(df['class'].sort_values().unique())}
id_to_label = {v: k for k, v in label_to_id.items()}

In [26]:
# --- Optimiser (AdamW) and learning-rate schedule ---
learning_rate = 1e-5
weight_decay = 0.01
warmup_steps = 1_000      # warm-up steps handed to the scheduler

# --- Batching ---
batch_size = 32           # used for both the training and validation loaders
max_seq_length = 128      # NOTE(review): not referenced by the code in this section

# --- Training length ---
num_train_epochs = 300    # upper bound; early stopping may end training sooner
PATIENCE = 30             # epochs without validation-loss improvement before stopping

In [None]:
## training loop
#set_seed(seed_val)
from sklearn.metrics import classification_report, f1_score

# Datasets: random augmentation for training, deterministic transform for validation.
train_dataset = ResNetDataset(df=df_train, label_to_id=label_to_id, train=True, text_field="designation", label_field="class", image_path_field="imageid", product_id_field='productid')
val_dataset = ResNetDataset(df=df_val, label_to_id=label_to_id, train=False, text_field="designation", label_field="class", image_path_field="imageid", product_id_field='productid')

# Training batches are drawn in random order; validation keeps the dataset order.
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(dataset=train_dataset,
                    batch_size=batch_size,
                    sampler=train_sampler)

val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size)

# Total optimiser steps if every epoch runs; the scheduler is stepped once per batch.
t_total = len(train_dataloader) * num_train_epochs

optimizer = AdamW(resnet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = get_scheduler(name="cosine", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)

criterion = nn.CrossEntropyLoss().to(device)

start = perf_counter()
best_val_loss = float('inf')
best_epoch = None
# CPU copy of the weights from the epoch with the lowest validation loss, so early
# stopping can hand back the best model rather than the one PATIENCE epochs later.
best_model_state = None
early_stopping_counter = 0

for epoch_num in trange(num_train_epochs, desc='Epochs'):
    # Training
    resnet_model.train()
    epoch_total_loss = 0

    # leave=False removes each finished per-epoch bar instead of stacking hundreds of them.
    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc='Batch', leave=False):
        b_text, b_labels, b_imgs = batch
        b_labels = b_labels.to(device)
        b_imgs = b_imgs.to(device)

        resnet_model.zero_grad()
        b_logits = resnet_model(text=b_text, image=b_imgs)

        loss = criterion(b_logits, b_labels)
        epoch_total_loss += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

    avg_loss = epoch_total_loss/len(train_dataloader)

    # Validation
    resnet_model.eval()
    val_preds = []
    val_labels = []
    val_total_loss = 0

    with torch.no_grad():
        for batch in val_dataloader:
            b_text, b_labels, b_imgs = batch
            b_labels = b_labels.to(device)
            b_imgs = b_imgs.to(device)

            b_logits = resnet_model(text=b_text, image=b_imgs)
            b_preds = torch.argmax(b_logits, dim=1).cpu().numpy()

            val_loss = criterion(b_logits, b_labels)
            val_total_loss += val_loss.item()

            val_preds.extend(b_preds)
            val_labels.extend(b_labels.cpu().numpy())

    avg_val_loss = val_total_loss/len(val_dataloader)
    f1 = f1_score(val_labels, val_preds, average='weighted')

    print('epoch =', epoch_num)
    print('    epoch_loss =', epoch_total_loss)
    print('    avg_epoch_loss =', avg_loss)
    print('    avg_val_loss =', avg_val_loss)
    print('    validation f1 weighted =', f1)
    print('    learning rate =', optimizer.param_groups[0]["lr"])

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch_num
        early_stopping_counter = 0
        # Snapshot on the CPU so the copy does not double GPU memory use;
        # clone() guarantees an independent copy when the model already lives on the CPU.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in resnet_model.state_dict().items()
        }
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= PATIENCE:
            print(f'Early stopping triggered after {epoch_num + 1} epochs')
            break

end = perf_counter()
resnet_training_time = end - start
print('Training completed in ', resnet_training_time, 'seconds')

# Leave the best-performing weights in resnet_model for any later evaluation cells.
# best_model_state stays None if no epoch produced a finite, improving validation loss.
if best_model_state is not None:
    resnet_model.load_state_dict(best_model_state)
    print(f'Restored weights from epoch {best_epoch} (avg_val_loss = {best_val_loss})')