In [None]:
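# Imports
# NOTE(review): the `from optuna.terminator.improvement.emmr import torch` line below
# rebinds the name `torch` imported just above it; this looks unintended — confirm and remove.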
import torch
import torchvision
from optuna.terminator.improvement.emmr import torch
from torch.nn import Embedding
from torch.utils.data import DataLoader , Subset
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
import random
import os
import torch.nn.functional as F



In [None]:
# --- Run configuration -------------------------------------------------------
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Random_seed = 123
learning_rate = 0.0001
num_epochs = 10
batch_size = 32
# Size of the embedding vector produced by the network.
# NOTE(review): this rebinds the name `Embedding` imported from `torch.nn` in the
# import cell; the class is not reachable under that name after this line.
Embedding = 128

# Apply the seed so that triplet sampling (Python `random`) and weight
# initialisation / shuffling (torch) are repeatable across kernel restarts.
random.seed(Random_seed)
torch.manual_seed(Random_seed)

In [None]:


class ArtBenchTriplet(Dataset):
    """Map-style dataset that yields ``(anchor, positive, negative)`` image triplets.

    The anchor is the row at the requested index of the selected split, the
    positive is a randomly drawn row with the same ``label`` and the negative
    is a randomly drawn row with a different ``label``.

    Args:
        csv_path: Path of the metadata CSV. The ``split``, ``label`` and
            ``path`` columns are read from it.
            NOTE(review): a later cell previews a ``name`` column of the same
            CSV — confirm that a ``path`` column really exists.
        img_dir: Directory onto which each ``path`` value is joined.
        split: Value of the ``split`` column to keep.
        transform: Optional callable applied to each of the three images.

    Raises:
        ValueError: If the selected split contains fewer than two distinct
            labels, because no negative example could ever be drawn.
    """

    def __init__(self, csv_path, img_dir, split='train', transform=None):
        metadata = pd.read_csv(csv_path)

        # Keep only the requested split; reset_index makes index labels equal
        # to positions, so the groups below can be used directly with .iloc.
        self.df = metadata[metadata['split'] == split].reset_index(drop=True)

        self.img_dir = img_dir
        self.transform = transform

        # {label: [row positions]} as plain Python lists for cheap sampling.
        self.label_to_indices = {
            label: rows.tolist()
            for label, rows in self.df.groupby('label').groups.items()
        }
        # Cached once instead of being rebuilt on every __getitem__ call.
        self._labels = list(self.label_to_indices)

        if len(self._labels) < 2:
            raise ValueError(
                f"split {split!r} has {len(self._labels)} distinct label(s); "
                "at least two are required to build triplets"
            )

    def _load_image(self, rel_path):
        """Open ``img_dir/rel_path`` and return it as an RGB image (file is closed)."""
        with Image.open(os.path.join(self.img_dir, rel_path)) as img:
            return img.convert('RGB')

    def _sample_positive(self, index, label):
        """Return a row position with the same label, different from ``index`` when possible."""
        candidates = self.label_to_indices[label]
        if len(candidates) < 2:
            # Only the anchor carries this label: reuse it rather than spin
            # forever looking for a different image.
            return index
        pos_index = random.choice(candidates)
        while pos_index == index:
            pos_index = random.choice(candidates)
        return pos_index

    def _sample_negative(self, label):
        """Return a row position whose label differs from ``label``."""
        other_labels = [l for l in self._labels if l != label]
        neg_label = random.choice(other_labels)
        return random.choice(self.label_to_indices[neg_label])

    def __getitem__(self, index):
        # Normalise negative indices so the "positive != anchor" comparison
        # is made against the actual row position.
        if index < 0:
            index += len(self.df)

        row = self.df.iloc[index]
        label = row['label']

        pos_index = self._sample_positive(index, label)
        neg_index = self._sample_negative(label)

        anchor_img = self._load_image(row['path'])
        pos_img = self._load_image(self.df.iloc[pos_index]['path'])
        neg_img = self._load_image(self.df.iloc[neg_index]['path'])

        # Apply the same transform pipeline to all three images.
        if self.transform is not None:
            anchor_img = self.transform(anchor_img)
            pos_img = self.transform(pos_img)
            neg_img = self.transform(neg_img)

        return anchor_img, pos_img, neg_img

    def __len__(self):
        return len(self.df)

In [None]:
# Data loading: image preprocessing pipelines.
# Both pipelines resize to 128x128, convert to a tensor and normalise every
# channel with mean 0.5 / std 0.5; only the training one adds a random flip.
_image_size = (128, 128)
_channel_mean = (0.5, 0.5, 0.5)
_channel_std = (0.5, 0.5, 0.5)

train_transform = transforms.Compose(
    [
        transforms.Resize(_image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_channel_mean, _channel_std),
    ]
)

valid_transform = transforms.Compose(
    [
        transforms.Resize(_image_size),
        transforms.ToTensor(),
        transforms.Normalize(_channel_mean, _channel_std),
    ]
)


# Dataset location, defined once so the datasets and the diagnostic read below
# cannot drift apart.
CSV_PATH = r'C:\Users\HP\PycharmProjects\PythonProject\Image-analysis\Style_dataset\ArtBench-10.csv'
IMG_DIR = r'C:\Users\HP\PycharmProjects\PythonProject\Image-analysis\Style_dataset\artbench-10-python\artbench-10-batches-py'

# On Windows, DataLoader workers are spawned as fresh processes that must be
# able to import the dataset class; a class defined in a notebook cell usually
# is not importable there, so load in the main process on that platform.
NUM_WORKERS = 0 if os.name == 'nt' else 4
# Pinned host memory only benefits host-to-GPU copies.
PIN_MEMORY = torch.cuda.is_available()

train_dataset = ArtBenchTriplet(
    csv_path=CSV_PATH,
    img_dir=IMG_DIR,
    split='train',
    transform=train_transform,
)

valid_dataset = ArtBenchTriplet(
    csv_path=CSV_PATH,
    img_dir=IMG_DIR,
    split='test',
    transform=valid_transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

print(f"Train loader has {len(train_loader)} batches.")
print(f"Valid loader has {len(valid_loader)} batches.")

# Quick look at the raw metadata. Only the columns that actually exist are
# previewed, so a differently named column does not abort the cell.
df = pd.read_csv(CSV_PATH)
print(df.columns)
_preview_columns = [c for c in ('name', 'label', 'split') if c in df.columns]
print(df[_preview_columns].head())


In [None]:
#Residual Block

class ResidualBlock(torch.nn.Module):
    """Down-sampling residual unit.

    Main path: 3x3 stride-2 conv -> BatchNorm -> ReLU -> 1x1 conv -> BatchNorm.
    Skip path: 1x1 stride-2 conv -> BatchNorm, so both paths end with the same
    shape. The two results are summed and passed through a ReLU.

    Args:
        channels: Indexable of three ints: input, intermediate and output
            channel counts.
    """

    def __init__(self,channels):
        super().__init__()
        nn = torch.nn
        c_in, c_mid, c_out = channels[0], channels[1], channels[2]

        main_path = [
            nn.Conv2d(c_in, c_mid, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(c_mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(c_out),
        ]
        self.block = nn.Sequential(*main_path)

        # Projection that matches the main path's channel count and stride.
        skip_path = [
            nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, padding=0),
            nn.BatchNorm2d(c_out),
        ]
        self.shortcut = nn.Sequential(*skip_path)

    def forward(self,x):
        """Return ``relu(block(x) + shortcut(x))``."""
        main = self.block(x)
        skip = self.shortcut(x)
        return torch.relu(main + skip)


In [None]:
# Model

class ConvNet(torch.nn.Module):
    """Embedding network: four residual stages, global average pooling, one linear layer.

    The output of the linear layer is L2-normalised along dim 1, so every
    returned row has unit Euclidean norm.

    Args:
        Output: Number of features produced by the final linear layer.
    """

    def __init__(self,Output):
        super().__init__()
        self.residual_block1 = ResidualBlock(channels=[3, 32, 64])
        self.residual_block2 = ResidualBlock(channels=[64, 64, 128])
        self.residual_block3 = ResidualBlock(channels=[128, 128, 256])
        self.residual_block4 = ResidualBlock(channels=[256, 256, 512])
        # Collapse whatever spatial size is left to 1x1 per channel.
        self.gap = torch.nn.AdaptiveAvgPool2d((1, 1))

        self.linear_1 = torch.nn.Linear(512, Output)

    def forward(self,x):
        """Map an image batch to unit-length embedding vectors."""
        stages = (
            self.residual_block1,
            self.residual_block2,
            self.residual_block3,
            self.residual_block4,
        )
        for stage in stages:
            x = stage(x)
        pooled = self.gap(x).flatten(1)
        embedding = self.linear_1(pooled)
        return F.normalize(embedding, p=2, dim=1)


# Build the embedding network and place it on the selected device
# (Module.to moves the parameters in place and returns the same module).
model = ConvNet(Output=Embedding).to(device)

# Triplet loss on Euclidean (p=2) distances with a margin of 1.0.
criterion = torch.nn.TripletMarginLoss(p=2, margin=1.0)

# AdamW with decoupled weight decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.02,
)



In [None]:
def compute_accuracy(model, data_loader, device=None):
    """Return the accuracy (in percent) of ``model`` over ``data_loader``.

    Two batch layouts are supported:

    * ``(features, targets)`` — classification accuracy: the index of the
      largest output along dim 1 is compared with ``targets``.
    * ``(anchors, positives, negatives)`` — triplet accuracy: a triplet counts
      as correct when the anchor embedding is closer (Euclidean distance) to
      the positive than to the negative.

    Args:
        model: Module to evaluate. It is switched to eval mode and left there.
        data_loader: Iterable of batches in one of the layouts above.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        A 0-dim float tensor with the accuracy in percent. The value is NaN
        when ``data_loader`` yields no examples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    # Tensor accumulator so an empty loader still reaches the final division.
    correct_pred = torch.zeros((), dtype=torch.long, device=device)
    num_examples = 0

    # Evaluation needs no gradients; skipping them avoids building autograd graphs.
    with torch.no_grad():
        for batch in data_loader:
            if len(batch) == 3:
                anchors, positives, negatives = (t.to(device) for t in batch)
                emb_a = model(anchors)
                emb_p = model(positives)
                emb_n = model(negatives)
                dist_pos = torch.nn.functional.pairwise_distance(emb_a, emb_p)
                dist_neg = torch.nn.functional.pairwise_distance(emb_a, emb_n)
                hits = dist_pos < dist_neg
                num_examples += anchors.size(0)
            else:
                features, targets = batch
                features = features.to(device)
                targets = targets.to(device)

                logits = model(features)
                # Index of the highest output is the predicted class.
                _, predicted_labels = torch.max(logits, 1)
                hits = predicted_labels == targets
                num_examples += targets.size(0)

            correct_pred += hits.sum()

    return correct_pred.float() / num_examples * 100

In [None]:


# Triplet training: one optimizer step per batch, mean loss reported every 50 batches.
for epoch in range(num_epochs):
    model.train()  # switch the model to training mode for this epoch
    running_loss = 0.0

    for batch_idx, (anchors, positives, negatives) in enumerate(train_loader):
        # Move the three image batches to the compute device.
        anchors = anchors.to(device)
        positives = positives.to(device)
        negatives = negatives.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Embed each branch with the same network, then score the triplet.
        emb_a, emb_p, emb_n = model(anchors), model(positives), model(negatives)
        loss = criterion(emb_a, emb_p, emb_n)

        # Back-propagate and update the weights.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Nothing to report until a full window of 50 batches has accumulated.
        if (batch_idx + 1) % 50 != 0:
            continue

        avg_loss = running_loss / 50
        print(f"Epoch [{epoch+1}/{num_epochs}] | Batch {batch_idx+1}/{len(train_loader)} | Loss: {avg_loss:.4f}")
        running_loss = 0.0

    # End-of-epoch marker.
    print(f"==> Finished Epoch {epoch+1} <==\n")