## Import

In [1]:
import torch
import torch.nn as nn
import pandas as pd
import timm
from MANIQA import *
from MANIQA2Transformer import *
import torch.nn.utils as F
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import math
import os
import numpy as np
import random
import warnings
warnings.filterwarnings(action='ignore') 

## Hyperparameter Settings

In [3]:
CFG = {
    'IMG_SIZE':224,
    'EPOCHS': 100, #Your Epochs,
    'LR': 3e-4, #Your Learning Rate,
    'BATCH_SIZE': 128, #Your Batch Size,
    'SEED':41
}

## Fixed Random-Seed

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) 

## Custom Dataset

In [4]:
class CustomDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandAugment(),
            transforms.ToTensor(),
            transforms.RandomErasing(),
            transforms.Normalize(mean=[0.5,0.5,0.5],
                        std=[0.5,0.5,0.5])
        ])
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['img_path']
        img = Image.open(img_path).convert('RGB')
        
        img = self.transform(img)
        
        # mos column 존재 여부에 따라 값을 설정
        mos = float(self.dataframe.iloc[idx]['mos']) if 'mos' in self.dataframe.columns else 0.0
        comment = self.dataframe.iloc[idx]['comments'] if 'comments' in self.dataframe.columns else ""

        
        return img, mos, comment

In [5]:
train_df = pd.read_csv('train.csv')
train_df

Unnamed: 0,img_name,img_path,mos,comments
0,41wy7upxzl,./train/41wy7upxzl.jpg,5.569231,the pink and blue really compliment each other...
1,ygujjq6xxt,./train/ygujjq6xxt.jpg,6.103175,love rhubarb! great colors!
2,wk321130q0,./train/wk321130q0.jpg,5.541985,i enjoy the textures and grungy feel to this. ...
3,w50dp2zjpg,./train/w50dp2zjpg.jpg,6.234848,"i like all the different colours in this pic, ..."
4,l7rqfxeuh0,./train/l7rqfxeuh0.jpg,5.190476,"i love these critters, just wish he was a litt..."
...,...,...,...,...
74563,zbevd0lyox,./train/zbevd0lyox.jpg,5.926108,"perfect balance here, in this soft serene image."
74564,w26yu6ee60,./train/w26yu6ee60.jpg,5.966346,very nice indeed. the sharpness and contrast a...
74565,a1pts9zzdx,./train/a1pts9zzdx.jpg,5.718447,nice tones and color for balance.
74566,pzbubeo03l,./train/pzbubeo03l.jpg,6.007843,i like the bold colors. nice sharp image.


## Define Model

In [2]:
example_vocab = 25000
model = MANIQA2transformer(example_vocab)
model

MANIQA2transformer(
  (cnn): MANIQA(
    (vit): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(8, 8), stride=(8, 8))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_fe

In [14]:
import matplotlib.pyplot as plt
def show_image(img, title=None):
    """Imshow for Tensor."""
    
    #unnormalize 
    img[0] = img[0] * 0.229
    img[1] = img[1] * 0.224 
    img[2] = img[2] * 0.225 
    img[0] += 0.485 
    img[1] += 0.456 
    img[2] += 0.406
    
    img = img.detach().cpu().numpy().transpose((1, 2, 0))
    
    
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [15]:
def greedy_decode(model, image):
    image = image.unsqueeze(0).cuda()
    mos, _ = model(image)
    return mos.item(), _

## Train

In [16]:
from sklearn.model_selection import train_test_split
x_train, x_valid, y_train, y_valid = train_test_split(train_df, train_df.mos, test_size=0.2,shuffle=True,random_state=True)
print(x_train.shape, x_valid.shape, y_train.shape, y_valid.shape)

(59654, 4) (14914, 4) (59654,) (14914,)


In [17]:
# 데이터 로드
train_data = pd.read_csv('train.csv')

# 단어 사전 생성
all_comments = ' '.join(train_data['comments']).split()
vocab = set(all_comments)
vocab = ['<PAD>', '<SOS>', '<EOS>'] + list(vocab)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

train_dataset = CustomDataset(x_train)
valid_dataset = CustomDataset(x_valid)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

model = MANIQA2transformer(len(vocab))


model.cuda()
criterion1 = nn.MSELoss()
criterion2 = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5,weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100,eta_min=0)
print_every = 100

# 학습
model.train()
valid_min_loss = np.Inf
for epoch in range(100):
    train_loss = 0
    train_mos_loss = 0
    train_cm_loss = 0
    valid_loss = 0
    valid_mos_loss = 0
    valid_cm_loss = 0
    train_loop = tqdm(train_loader, leave=True)
    for imgs, mos, comments in train_loop:
        imgs, mos = imgs.float().cuda(), mos.float().cuda()
        
        # Batch Preprocessing
        src_tensor = torch.zeros((len(comments), len(max(comments, key=len)))).long().cuda()
        for i, comment in enumerate(comments):
            tokenized = ['<SOS>'] + comment.split() + ['<EOS>']
            src_tensor[i, :len(tokenized)] = torch.tensor([word2idx[word] for word in tokenized])
        
        
        tgt = src_tensor[:,1:]
        sequence_len = src_tensor[:,1:].size(1)
        tgt_mask = model.get_tgt_mask(sequence_len).cuda()
        # Forward & Loss
        predicted_mos, predicted_comments = model(imgs, src_tensor[:,:-1],tgt, tgt_mask)
        loss1 = criterion1(predicted_mos, mos)
        loss2 = criterion2(predicted_comments.view(-1, len(vocab)), tgt.reshape(-1))
        loss = loss1 + loss2

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_mos_loss += loss1.item()
        train_cm_loss += loss2.item()
        train_loop.set_description(f"Epoch {epoch + 1}")
        train_loop.set_postfix(loss=loss.item(),mos_loss=loss1.item(),cm_loss=loss2.item())

    valid_loop = tqdm(valid_loader, leave=True)

    model.eval()
    with torch.no_grad():
        for imgs, mos, comments in valid_loop:
            imgs, mos = imgs.float().cuda(), mos.float().cuda()
                
            # Batch Preprocessing
            src_tensor = torch.zeros((len(comments), len(max(comments, key=len)))).long().cuda()
            for i, comment in enumerate(comments):
                tokenized = ['<SOS>'] + comment.split() + ['<EOS>']
                src_tensor[i, :len(tokenized)] = torch.tensor([word2idx[word] for word in tokenized])
                
                
            tgt = src_tensor[:,1:]
            sequence_len = src_tensor[:,1:].size(1)
            tgt_mask = model.get_tgt_mask(sequence_len).cuda()
            # Forward & Loss
            predicted_mos, predicted_comments = model(imgs, src_tensor[:,:-1],tgt, tgt_mask)
            loss1 = criterion1(predicted_mos, mos)
            loss2 = criterion2(predicted_comments.view(-1, len(vocab)), tgt.reshape(-1))
            loss = loss1 + loss2

            valid_loss += loss.item()
            valid_mos_loss += loss1.item()
            valid_cm_loss += loss2.item()
            valid_loop.set_description(f"Epoch {epoch + 1}")
            valid_loop.set_postfix(loss=loss.item(),mos_loss=loss1.item(),cm_loss=loss2.item())



    print(f"Epoch {epoch + 1} total train loss: {train_loss / len(train_loader):.4f} valid loss: {valid_loss / len(valid_loader):.4f}")
    print(f"Epoch {epoch + 1} Mos train loss: {train_mos_loss / len(train_loader):.4f} valid loss: {valid_mos_loss / len(valid_loader):.4f}")
    print(f"Epoch {epoch + 1} comments train loss: {train_cm_loss / len(train_loader):.4f} valid loss: {valid_cm_loss / len(valid_loader):.4f}")
    if valid_min_loss > valid_loss / len(valid_loader):
       print('valid loss decreased {:.4f} ---> {:.4f}'.format(valid_min_loss,(valid_loss/len(valid_loader))))
       valid_min_loss = valid_loss / len(valid_loader)
       torch.save(model.state_dict(),'MANIAQA+transformer.pt')

Epoch 1: 100%|██████████| 7457/7457 [1:12:39<00:00,  1.71it/s, cm_loss=2.35, loss=3.01, mos_loss=0.659] 
Epoch 1: 100%|██████████| 933/933 [06:43<00:00,  2.31it/s, cm_loss=2.32, loss=2.75, mos_loss=0.429] 


Epoch 1 total train loss: 4.9592 valid loss: 2.5877
Epoch 1 Mos train loss: 1.0328 valid loss: 0.4086
Epoch 1 comments train loss: 3.9263 valid loss: 2.1791
valid loss decreased inf ---> 2.5877


Epoch 2: 100%|██████████| 7457/7457 [1:12:28<00:00,  1.71it/s, cm_loss=1.06, loss=1.44, mos_loss=0.381]   
Epoch 2: 100%|██████████| 933/933 [06:50<00:00,  2.27it/s, cm_loss=1.11, loss=1.73, mos_loss=0.615]  


Epoch 2 total train loss: 1.7378 valid loss: 1.4893
Epoch 2 Mos train loss: 0.3586 valid loss: 0.3758
Epoch 2 comments train loss: 1.3792 valid loss: 1.1136
valid loss decreased 2.5877 ---> 1.4893


Epoch 3: 100%|██████████| 7457/7457 [1:12:21<00:00,  1.72it/s, cm_loss=0.933, loss=1.07, mos_loss=0.138]  
Epoch 3: 100%|██████████| 933/933 [06:47<00:00,  2.29it/s, cm_loss=0.807, loss=0.914, mos_loss=0.107] 


Epoch 3 total train loss: 1.0787 valid loss: 1.1137
Epoch 3 Mos train loss: 0.2407 valid loss: 0.2743
Epoch 3 comments train loss: 0.8381 valid loss: 0.8394
valid loss decreased 1.4893 ---> 1.1137


Epoch 4: 100%|██████████| 7457/7457 [1:11:50<00:00,  1.73it/s, cm_loss=0.974, loss=1.05, mos_loss=0.0809] 
Epoch 4: 100%|██████████| 933/933 [06:41<00:00,  2.32it/s, cm_loss=0.728, loss=0.79, mos_loss=0.0617] 


Epoch 4 total train loss: 0.7855 valid loss: 0.9557
Epoch 4 Mos train loss: 0.1556 valid loss: 0.2533
Epoch 4 comments train loss: 0.6298 valid loss: 0.7024
valid loss decreased 1.1137 ---> 0.9557


Epoch 5: 100%|██████████| 7457/7457 [1:11:54<00:00,  1.73it/s, cm_loss=0.369, loss=0.417, mos_loss=0.0477]   
Epoch 5: 100%|██████████| 933/933 [06:46<00:00,  2.30it/s, cm_loss=0.687, loss=0.928, mos_loss=0.241] 


Epoch 5 total train loss: 0.6161 valid loss: 0.8729
Epoch 5 Mos train loss: 0.1104 valid loss: 0.2476
Epoch 5 comments train loss: 0.5057 valid loss: 0.6253
valid loss decreased 0.9557 ---> 0.8729


Epoch 6: 100%|██████████| 7457/7457 [1:12:16<00:00,  1.72it/s, cm_loss=0.286, loss=0.335, mos_loss=0.049]   
Epoch 6: 100%|██████████| 933/933 [06:46<00:00,  2.30it/s, cm_loss=0.64, loss=0.718, mos_loss=0.078]  


Epoch 6 total train loss: 0.5039 valid loss: 0.8134
Epoch 6 Mos train loss: 0.0843 valid loss: 0.2366
Epoch 6 comments train loss: 0.4195 valid loss: 0.5768
valid loss decreased 0.8729 ---> 0.8134


Epoch 7: 100%|██████████| 7457/7457 [1:13:00<00:00,  1.70it/s, cm_loss=0.251, loss=0.349, mos_loss=0.0983]    
Epoch 7: 100%|██████████| 933/933 [06:54<00:00,  2.25it/s, cm_loss=0.606, loss=0.626, mos_loss=0.0206]


Epoch 7 total train loss: 0.4235 valid loss: 0.7746
Epoch 7 Mos train loss: 0.0710 valid loss: 0.2258
Epoch 7 comments train loss: 0.3525 valid loss: 0.5488
valid loss decreased 0.8134 ---> 0.7746


Epoch 8: 100%|██████████| 7457/7457 [1:12:43<00:00,  1.71it/s, cm_loss=0.578, loss=0.654, mos_loss=0.076]     
Epoch 8: 100%|██████████| 933/933 [06:46<00:00,  2.30it/s, cm_loss=0.575, loss=0.579, mos_loss=0.00427]


Epoch 8 total train loss: 0.3576 valid loss: 0.7464
Epoch 8 Mos train loss: 0.0609 valid loss: 0.2267
Epoch 8 comments train loss: 0.2967 valid loss: 0.5197
valid loss decreased 0.7746 ---> 0.7464


Epoch 9: 100%|██████████| 7457/7457 [1:12:03<00:00,  1.72it/s, cm_loss=0.363, loss=0.383, mos_loss=0.0206]    
Epoch 9: 100%|██████████| 933/933 [06:45<00:00,  2.30it/s, cm_loss=0.553, loss=0.558, mos_loss=0.0053]


Epoch 9 total train loss: 0.3034 valid loss: 0.7314
Epoch 9 Mos train loss: 0.0550 valid loss: 0.2227
Epoch 9 comments train loss: 0.2484 valid loss: 0.5087
valid loss decreased 0.7464 ---> 0.7314


Epoch 10: 100%|██████████| 7457/7457 [1:11:19<00:00,  1.74it/s, cm_loss=0.138, loss=0.177, mos_loss=0.0388]    
Epoch 10: 100%|██████████| 933/933 [06:42<00:00,  2.32it/s, cm_loss=0.523, loss=0.525, mos_loss=0.00168]


Epoch 10 total train loss: 0.2563 valid loss: 0.7301
Epoch 10 Mos train loss: 0.0499 valid loss: 0.2256
Epoch 10 comments train loss: 0.2064 valid loss: 0.5045
valid loss decreased 0.7314 ---> 0.7301


Epoch 11: 100%|██████████| 7457/7457 [1:11:11<00:00,  1.75it/s, cm_loss=0.174, loss=0.192, mos_loss=0.018]     
Epoch 11: 100%|██████████| 933/933 [06:41<00:00,  2.32it/s, cm_loss=0.506, loss=0.566, mos_loss=0.06]   


Epoch 11 total train loss: 0.2136 valid loss: 0.7058
Epoch 11 Mos train loss: 0.0447 valid loss: 0.2182
Epoch 11 comments train loss: 0.1689 valid loss: 0.4876
valid loss decreased 0.7301 ---> 0.7058


Epoch 12: 100%|██████████| 7457/7457 [1:11:33<00:00,  1.74it/s, cm_loss=0.323, loss=0.364, mos_loss=0.0407]    
Epoch 12: 100%|██████████| 933/933 [06:45<00:00,  2.30it/s, cm_loss=0.463, loss=0.491, mos_loss=0.0275] 


Epoch 12 total train loss: 0.1769 valid loss: 0.6950
Epoch 12 Mos train loss: 0.0420 valid loss: 0.2302
Epoch 12 comments train loss: 0.1350 valid loss: 0.4649
valid loss decreased 0.7058 ---> 0.6950


Epoch 13: 100%|██████████| 7457/7457 [1:12:05<00:00,  1.72it/s, cm_loss=0.0464, loss=0.0939, mos_loss=0.0475]   
Epoch 13: 100%|██████████| 933/933 [06:45<00:00,  2.30it/s, cm_loss=0.477, loss=0.481, mos_loss=0.00426]


Epoch 13 total train loss: 0.1438 valid loss: 0.6760
Epoch 13 Mos train loss: 0.0390 valid loss: 0.2140
Epoch 13 comments train loss: 0.1048 valid loss: 0.4620
valid loss decreased 0.6950 ---> 0.6760


Epoch 14: 100%|██████████| 7457/7457 [1:12:02<00:00,  1.72it/s, cm_loss=0.0304, loss=0.0598, mos_loss=0.0294]    
Epoch 14: 100%|██████████| 933/933 [06:46<00:00,  2.30it/s, cm_loss=0.393, loss=0.41, mos_loss=0.0165]  


Epoch 14 total train loss: 0.1141 valid loss: 0.6718
Epoch 14 Mos train loss: 0.0359 valid loss: 0.2175
Epoch 14 comments train loss: 0.0782 valid loss: 0.4543
valid loss decreased 0.6760 ---> 0.6718


Epoch 15: 100%|██████████| 7457/7457 [1:12:06<00:00,  1.72it/s, cm_loss=0.13, loss=0.152, mos_loss=0.0216]       
Epoch 15: 100%|██████████| 933/933 [06:45<00:00,  2.30it/s, cm_loss=0.411, loss=0.417, mos_loss=0.00532]


Epoch 15 total train loss: 0.0897 valid loss: 0.6624
Epoch 15 Mos train loss: 0.0343 valid loss: 0.2131
Epoch 15 comments train loss: 0.0555 valid loss: 0.4493
valid loss decreased 0.6718 ---> 0.6624


Epoch 16:  59%|█████▉    | 4396/7457 [42:30<29:36,  1.72it/s, cm_loss=0.0891, loss=0.178, mos_loss=0.0885]     


KeyboardInterrupt: 

## Inference & Submit

In [1]:
model = BaseModel(vocab_size=len(vocab)).cuda()
model.load_state_dict(torch.load('MANIQA+transformer.pt'))
test_data = pd.read_csv('test.csv')
test_dataset = CustomDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
predicted_mos_list = []
predicted_comments_list = []

def greedy_decode(model, image):
    image = image.unsqueeze(0).cuda()
    mos, _ = model(image)
    return mos.item(), _

model.eval()
with torch.no_grad():
    for imgs, _, _ in tqdm(test_loader):
        for img in imgs:
            img = img.float().cuda()
            mos, temp = greedy_decode(model, img)

            features = model.cnn(img.unsqueeze(0))

            caps= model.generate_caption(features)
            caption = ' '.join(caps)
            predicted_mos_list.append(mos)
            predicted_comments_list.append(caption)
            # show_image(imgs[0],title=caption)

# 결과 저장
result_df = pd.DataFrame({
    'img_name': test_data['img_name'],
    'mos': predicted_mos_list,
    'comments': predicted_comments_list  # 캡션 부분은 위에서 생성한 것을 사용
})

# 예측 결과에 NaN이 있다면, 제출 시 오류가 발생하므로 후처리 진행 (sample_submission.csv과 동일하게)
result_df['comments'] = result_df['comments'].fillna('Nice Image.')
result_df.to_csv('submit.csv', index=False)

print("Inference completed and results saved to submit.csv.")