In [1]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import wandb

In [2]:
wandb.init(entity="ameyar3103-iiit-hyderabad",project="recurrent_conv_art", config={
    "epochs": 20,
    "batch_size": 64,
    "learning_rate": 0.001,
    "model": "RecurrentCNN"
})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mameyar3103[0m ([33mameyar3103-iiit-hyderabad[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Data loading

In [3]:
df_train = pd.read_csv('wikiart_csv/genre_train.csv',header=None, names=["image_path", "genre_id"])
df_val = pd.read_csv('wikiart_csv/genre_val.csv',header=None, names=["image_path", "genre_id"])

In [4]:
# get the number of classes
num_classes = 10 # from genre_class.txt


In [5]:
# Gather input data
train_images = df_train['image_path'].values
train_labels = df_train['genre_id'].values

val_images = df_val['image_path'].values
val_labels = df_val['genre_id'].values

In [6]:
from torchvision import transforms
import cv2

## Preprocess data and create test and train dataset

In [7]:
# create test and train dataset for dataloader

def get_image(image_path,image_size=224):
    try:
        img = cv2.imread('./wikiart/' + image_path)
        if img is None:
            raise ValueError(f"Image not loaded: ./wikiart/{image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w, _ = img.shape
        scale = 256 / min(h, w)
        new_w = int(w * scale)
        new_h = int(h * scale)
        img_resized = cv2.resize(img, (new_w, new_h))
        start_x = (new_w - image_size) // 2
        start_y = (new_h - image_size) // 2
        img_cropped = img_resized[start_y:start_y+image_size, start_x:start_x+image_size]
        img_cropped = img_cropped.astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_cropped).permute(2, 0, 1)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img_tensor = (img_tensor - mean) / std
        return img_tensor
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return torch.zeros(3, image_size, image_size)

class WikiArtDataset(torch.utils.data.Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # image_vectors = []
        # for image in self.images:
        #     image_emb = get_image(image)
        #     image_vectors.append(image_emb)
        # image = torch.stack(image_vectors)
        image = self.images[idx]
        # label should be a one-hot encoded vector
        label = torch.zeros(num_classes)
        label[self.labels[idx]] = 1

        return image, label

train_dataset = WikiArtDataset(train_images, train_labels)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataset = WikiArtDataset(val_images, val_labels)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

for i, (images, labels) in enumerate(train_loader):
    print(images)
    print(labels)
    break

('Impressionism/robert-julian-onderdonk_goat-herder-at-the-san-antonio-quarry-1909.jpg', 'Impressionism/arthur-verona_neagoe-basarab-study.jpg', 'Early_Renaissance/paolo-uccello_st-francis.jpg', 'High_Renaissance/michelangelo_the-ancestors-of-christ-manasseh-amon-1512.jpg', 'Baroque/adriaen-brouwer_inn-with-drunken-peasants.jpg', 'Realism/vincent-van-gogh_farmhouses-in-loosduinen-near-the-hague-at-twilight-1883(1).jpg', 'Romanticism/jan-matejko_jadwiga.jpg', 'Post_Impressionism/bertalan-por_brookside-1919.jpg', 'Realism/ivan-shishkin_fir.jpg', 'Impressionism/camille-pissarro_landscape-with-a-man-digging-1877.jpg', 'Romanticism/dante-gabriel-rossetti_study-for-a-vision-of-fiammetta.jpg', 'Expressionism/martiros-saryan_gohtan-mountains-1914.jpg', 'Realism/vasily-vereshchagin_parsi-priest-fire-worshiper-bombay-1876.jpg', 'Realism/klavdy-lebedev_spat-on-the-terrace.jpg', 'Impressionism/pierre-auguste-renoir_young-girl-in-a-flowered-hat-1905.jpg', 'Realism/johan-hendrik-weissenbruch_figures

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class RecurrentCNN(nn.Module):
    def __init__(self, num_classes, lstm_hidden_size=256, dropout_prob=0.5):
        super(RecurrentCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((14, 56))
        self.lstm_input_size = 64 * 56
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm = nn.LSTM(input_size=self.lstm_input_size, hidden_size=lstm_hidden_size,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(2 * lstm_hidden_size, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))      
        x = self.pool1(x)              
        x = F.relu(self.conv2(x))      
        x = self.pool2(x)              
        x = self.adaptive_pool(x)      
        x = x.permute(0, 2, 1, 3).contiguous()  
        batch_size, seq_len, channels, width = x.shape  
        x = x.view(batch_size, seq_len, channels * width)  
        lstm_out, _ = self.lstm(x)  
        x = lstm_out.mean(dim=1)    
        x = self.dropout(x)
        x = self.fc(x)
        return x
    
model = RecurrentCNN(num_classes)
model.to('cuda')

# Loss and optimizer
import torch.optim as optim

wandb.watch(model, log="all")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## Training the model

In [9]:
# Train the model
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for image_paths, labels in train_bar:
        image_tensors = torch.stack([get_image(image_path) for image_path in image_paths])
        images = image_tensors.to('cuda')
        labels = labels.to('cuda')
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        train_bar.set_postfix(loss=loss.item())
    
    avg_train_loss = running_loss / len(train_loader)
    wandb.log({"epoch": epoch+1, "train_loss": avg_train_loss})
    
    # Validation Loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        val_bar = tqdm(val_loader, desc="Validation")
        for image_paths, labels in val_bar:
            image_tensors = torch.stack([get_image(image_path) for image_path in image_paths])
            image_tensors = image_tensors.to('cuda')
            labels = labels.to('cuda')
            outputs = model(image_tensors)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.argmax(dim=1)).sum().item()
            val_bar.set_postfix(loss=loss.item())
    
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    wandb.log({"val_loss": avg_val_loss, "val_accuracy": val_accuracy})
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")
    if(epoch%5==0):
        torch.save(model.state_dict(), f"recurrent_cnn_epoch_{epoch}_genre.pth")
        torch.save(optimizer.state_dict(), f"recurrent_cnn_optimizer_epoch_{epoch}_genre.pth")

Epoch 1/20:  53%|█████▎    | 374/711 [06:07<05:21,  1.05it/s, loss=1.49]Corrupt JPEG data: premature end of data segment
Epoch 1/20:  98%|█████████▊| 696/711 [11:19<00:12,  1.16it/s, loss=1.68]Corrupt JPEG data: bad Huffman code
Epoch 1/20: 100%|██████████| 711/711 [11:34<00:00,  1.02it/s, loss=1.7] 
Validation: 100%|██████████| 305/305 [04:29<00:00,  1.13it/s, loss=1.2]  


Epoch 1/20 - Train Loss: 1.6431, Val Loss: 1.4642, Val Accuracy: 49.14%


Epoch 2/20:  62%|██████▏   | 438/711 [07:15<03:51,  1.18it/s, loss=1.45]Corrupt JPEG data: premature end of data segment
Epoch 2/20:  72%|███████▏  | 512/711 [08:25<03:08,  1.06it/s, loss=1.74]Corrupt JPEG data: bad Huffman code
Epoch 2/20: 100%|██████████| 711/711 [11:41<00:00,  1.01it/s, loss=1.34]
Validation: 100%|██████████| 305/305 [04:35<00:00,  1.11it/s, loss=0.956]


Epoch 2/20 - Train Loss: 1.4466, Val Loss: 1.4048, Val Accuracy: 50.63%


Epoch 3/20:  26%|██▌       | 185/711 [03:08<08:27,  1.04it/s, loss=1.22]Corrupt JPEG data: bad Huffman code
Epoch 3/20:  54%|█████▍    | 385/711 [06:19<04:57,  1.10it/s, loss=1.48]Corrupt JPEG data: premature end of data segment
Epoch 3/20: 100%|██████████| 711/711 [11:28<00:00,  1.03it/s, loss=1.43] 
Validation: 100%|██████████| 305/305 [04:28<00:00,  1.14it/s, loss=1.27] 


Epoch 3/20 - Train Loss: 1.3456, Val Loss: 1.3164, Val Accuracy: 54.04%


Epoch 4/20:  65%|██████▌   | 463/711 [07:16<04:06,  1.01it/s, loss=1.2]  Corrupt JPEG data: bad Huffman code
Epoch 4/20:  83%|████████▎ | 591/711 [09:19<01:47,  1.11it/s, loss=1.46] Corrupt JPEG data: premature end of data segment
Epoch 4/20: 100%|██████████| 711/711 [11:16<00:00,  1.05it/s, loss=1.2]  
Validation: 100%|██████████| 305/305 [04:26<00:00,  1.14it/s, loss=1.36] 


Epoch 4/20 - Train Loss: 1.2404, Val Loss: 1.2897, Val Accuracy: 55.05%


Epoch 5/20:  31%|███       | 217/711 [03:24<08:41,  1.06s/it, loss=0.926]Corrupt JPEG data: bad Huffman code
Epoch 5/20:  47%|████▋     | 336/711 [05:20<05:55,  1.06it/s, loss=1.44] Corrupt JPEG data: premature end of data segment
Epoch 5/20: 100%|██████████| 711/711 [11:17<00:00,  1.05it/s, loss=1.04] 
Validation: 100%|██████████| 305/305 [04:28<00:00,  1.14it/s, loss=1.28] 


Epoch 5/20 - Train Loss: 1.1248, Val Loss: 1.2582, Val Accuracy: 56.26%


Epoch 6/20:  31%|███       | 218/711 [03:27<07:56,  1.03it/s, loss=0.926]Corrupt JPEG data: premature end of data segment
Epoch 6/20:  35%|███▍      | 246/711 [03:54<07:28,  1.04it/s, loss=0.839]Corrupt JPEG data: bad Huffman code
Epoch 6/20: 100%|██████████| 711/711 [11:14<00:00,  1.05it/s, loss=1.03] 
Validation: 100%|██████████| 305/305 [04:26<00:00,  1.15it/s, loss=1.63] 


Epoch 6/20 - Train Loss: 0.9760, Val Loss: 1.2952, Val Accuracy: 56.36%


Epoch 7/20:  43%|████▎     | 305/711 [04:49<06:21,  1.06it/s, loss=0.529]Corrupt JPEG data: bad Huffman code
Epoch 7/20:  56%|█████▌    | 399/711 [06:17<04:40,  1.11it/s, loss=0.841]Corrupt JPEG data: premature end of data segment
Epoch 7/20: 100%|██████████| 711/711 [11:12<00:00,  1.06it/s, loss=0.709]
Validation: 100%|██████████| 305/305 [04:29<00:00,  1.13it/s, loss=1.55] 


Epoch 7/20 - Train Loss: 0.7983, Val Loss: 1.3486, Val Accuracy: 56.16%


Epoch 8/20:   0%|          | 3/711 [00:03<12:03,  1.02s/it, loss=0.813]Corrupt JPEG data: premature end of data segment
Epoch 8/20:  38%|███▊      | 269/711 [04:20<07:08,  1.03it/s, loss=0.511]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x744d1bc5e4d0>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe