In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

In [None]:
class CUB200Dataset(Dataset):
    """Dataset wrapper pairing pre-loaded images with their class id and class name.

    Args:
        images_tensor: Indexable collection of images; its length defines the
            dataset size. Items are passed to ``transforms.ToPILImage``.
        img_id_to_class_id: Mapping from image id to class id.
        train_val_img_ids: Sequence of image ids; position ``i`` is taken to be
            the id of ``images_tensor[i]``.
        class_id_to_class_name: Mapping from class id to a readable class name.
        transform: Optional callable applied to the PIL image of each sample.
    """

    def __init__(self, images_tensor, img_id_to_class_id, train_val_img_ids, class_id_to_class_name, transform=None):
        self.images = images_tensor
        # Resolve each image position to its class id once, then reuse the ids
        # to look up the readable names.
        class_ids = [img_id_to_class_id[train_val_img_ids[pos]] for pos in range(len(images_tensor))]
        self.labels = class_ids
        self.label_names = [class_id_to_class_name[class_id] for class_id in class_ids]
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, class_id, class_name)`` for position ``idx``."""
        # Image is expected as (C, H, W), e.g. (3, 64, 64); convert to PIL so
        # that PIL-based transforms can be applied.
        pil_image = transforms.ToPILImage()(self.images[idx])
        sample = self.transform(pil_image) if self.transform else pil_image
        return sample, self.labels[idx], self.label_names[idx]

In [None]:
# Preprocessing pipeline applied to every sample: resize to 224x224, convert
# the PIL image to a tensor, then normalize each channel.
resnet_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # standard ImageNet values
        std=[0.229, 0.224, 0.225]
    )
])

In [None]:
# Load the image data and the metadata from disk.
data_images = torch.load('dataset/imgs_train_val_64x64.pth')

# The metadata is loaded with weights_only=True; frozenset is allow-listed for
# the duration of the load.
with torch.serialization.safe_globals([frozenset]):
    data_labels = torch.load("dataset/metadata.pth", weights_only=True)

img_ids = data_labels['train_val_img_ids']
num_total = len(img_ids)
# NOTE(review): num_train / num_val describe an 80/20 split, but nothing in the
# visible cells uses them -- the DataLoader below iterates over the full
# dataset. Confirm whether a train/validation split was intended.
num_train = int(0.8 * num_total)
num_val = num_total - num_train

img_id_to_class_id = data_labels['img_id_to_class_id']
class_id_to_class_name = data_labels['class_id_to_class_name']
dataset = CUB200Dataset(data_images, 
                        img_id_to_class_id,
                        img_ids,
                        class_id_to_class_name,
                        transform=resnet_transform)

# Shuffled batches of 32, loaded in the main process (num_workers = 0).
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers = 0)

In [None]:
# Sanity check: fetch a single batch, print its shapes, and stop.
for images, labels, _ in dataloader:
    print("Image batch shape:", images.shape)  # [B, 3, 224, 224]
    print("Label batch shape:", labels.shape)  # [B]
    break

In [None]:
# Inverse of the ImageNet normalization used in resnet_transform.
# Normalize computes (x - mean) / std, so using mean = -m/s and std = 1/s
# yields (x + m/s) * s = x * s + m, which undoes a Normalize(mean=m, std=s).
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)

def imshow(img, title=None):
    """Display one normalized (C, H, W) image tensor with matplotlib.

    Args:
        img: Image tensor that went through the normalization which
            ``inv_normalize`` reverses.
        title: Optional figure title; omitted when falsy.
    """
    # Undo the normalization, then clip into the [0, 1] range for display.
    restored = torch.clamp(inv_normalize(img), 0, 1)
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(restored.permute(1, 2, 0).numpy())
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

def show_images_in_row(images, labels=None, class_names=None, num_images=5):
    """Plot the first ``num_images`` images of a batch side by side.

    Args:
        images: Batch of normalized image tensors, each (C, H, W).
        labels: Optional per-image class ids. Used for the subplot title only
            when ``class_names`` is not provided.
        class_names: Optional per-image readable class names; preferred for
            the subplot title when provided.
        num_images: Maximum number of images to draw. Clamped to the batch
            size so a short (e.g. final) batch does not raise ``IndexError``.
    """
    num_images = min(num_images, len(images))
    if num_images <= 0:
        # Nothing to draw; avoid creating an empty figure.
        return

    plt.figure(figsize=(15, 3))
    for i in range(num_images):
        # Undo the normalization and clip into [0, 1] for display.
        img = torch.clamp(inv_normalize(images[i]), 0, 1)
        npimg = img.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)

        plt.subplot(1, num_images, i + 1)
        plt.imshow(npimg)
        if class_names is not None:
            plt.title(f"Label: {class_names[i]}")
        elif labels is not None:
            # No names available: fall back to the numeric class id instead of
            # failing on class_names[i] with class_names=None.
            plt.title(f"Label: {int(labels[i])}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:

# Visual check: draw the first five images of one batch with their class
# names, then print those names and the batch's class ids.
dataiter = iter(dataloader)
images, labels, labels_name = next(dataiter)
show_images_in_row(images, labels, class_names=labels_name, num_images=5)
print(labels_name[:5])
print(labels)

In [None]:
# Number of classes as recorded in the metadata; used as the output size of
# the classifier layer defined further down.
num_classes = data_labels['num_classes']
num_classes

In [None]:
# idx = 120
# lbl_tensor = data_labels['class_id_to_class_name'][data_labels['img_id_to_class_id'][data_labels['img_ids'][idx]]]

In [None]:
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

# Pretrained ResNet-50 (IMAGENET1K_V2 weights) with every parameter frozen.
resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
for param in resnet.parameters():
    param.requires_grad = False

# Drop the last two child modules and keep the rest as the feature extractor.
# NOTE(review): in torchvision's ResNet these are presumably the global
# average pool and the fc classifier, leaving a spatial feature map -- confirm.
layers =  list(resnet.children())[:-2]
model = nn.Sequential(*layers)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
# fmap_size = 8
# features = 512

# class Cnn_Resnet_Without_Attention(nn.Module):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.feature = model
#         self.dropout = nn.Dropout(.5)
#         self.fc = nn.Linear(features, num_classes)
#         nn.init.xavier_normal_(self.fc.weight.data)
#         if self.fc.bias is not None:
#             torch.nn.init.constant_(self.fc.bias.data, val=0)
        
#     def forward(self, x):
#         bs = x.size()[0]

#         x = self.feature(x)
#         x = x.view(bs, features, fmap_size ** 2)
#         # Batch matrix multiplication
#         x = torch.bmm(x, torch.transpose(x, 1, 2))/ (fmap_size ** 2)
#         x = torch.sqrt(x + 1e-5)
#         x = self.dropout(x)
#         x = self.fc(x)

#         return x

In [None]:
from torch import FloatTensor
import torch.nn.functional as F

def new_parameter(*size):
    """Create a trainable float32 parameter of shape ``size``, Xavier-normal initialized.

    Args:
        *size: Dimensions of the parameter. ``xavier_normal_`` needs at least
            two dimensions to compute fan-in / fan-out.

    Returns:
        nn.Parameter: Parameter with ``requires_grad=True``.
    """
    # torch.empty replaces the legacy FloatTensor(*size) constructor; the
    # uninitialized memory is fully overwritten by the initializer below.
    out = nn.Parameter(torch.empty(*size, dtype=torch.float32), requires_grad=True)
    torch.nn.init.xavier_normal_(out)
    return out


class Attention(nn.Module):
    """Soft-attention pooling over dim 1 of a 3-D input.

    A learned vector of length ``attention_size`` scores every slice along
    dim 1; the scores are softmax-normalized and used as weights for a
    weighted sum over that dimension.

    Args:
        attention_size: Size of the last input dimension.
    """

    def __init__(self, attention_size):
        super(Attention, self).__init__()
        self.attention = new_parameter(attention_size, 1)

    def forward(self, x_in):
        """Pool ``x_in`` of shape (bs, n, attention_size) into (bs, attention_size)."""
        # (bs, n, attention_size) @ (attention_size, 1) -> (bs, n, 1).
        # Squeeze ONLY the trailing axis: a bare squeeze() would also remove
        # the n axis when n == 1 (softmax would then run across the batch).
        attention_score = torch.matmul(x_in, self.attention).squeeze(-1)
        # Normalize the scores over the n positions and restore the trailing
        # axis so the weights broadcast against x_in.
        attention_score = F.softmax(attention_score, dim=-1).unsqueeze(-1)
        scored_x = x_in * attention_score

        # Weighted sum across dim 1 gives the expected feature vector.
        condensed_x = torch.sum(scored_x, dim=1)

        return condensed_x

In [None]:
# Expected backbone output for the inputs used in this notebook: `features`
# channels on a fmap_size x fmap_size grid. The forward pass below infers the
# spatial size from the tensor itself; these values document the expectation.
fmap_size = 7
features = 2048

class Cnn_Resnet(nn.Module):
    """Classifier on top of the frozen module-level backbone ``model``.

    The backbone feature map is flattened to (bs, C, H*W) and turned into a
    (bs, C, C) matrix via a batched outer product, followed by an element-wise
    square root. That matrix is reduced over dim 1 to a (bs, C) vector -- with
    learned attention weights when ``attention`` is true, with a plain mean
    otherwise -- and then passed through dropout and a linear layer.

    Args:
        attention: If true, pool with the ``Attention`` module; otherwise use
            mean pooling.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, attention = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attention = attention
        self.feature = model
        self.dropout = nn.Dropout(.5)
        self.fc = nn.Linear(features, num_classes)
        self.attn = Attention(features)
        nn.init.xavier_normal_(self.fc.weight.data)
        if self.fc.bias is not None:
            torch.nn.init.constant_(self.fc.bias.data, val=0)
        # Start with the shared backbone in inference mode (see train()).
        self.feature.eval()

    def train(self, mode=True):
        """Set train/eval mode on the head while keeping the backbone in eval mode.

        The backbone is shared between instances and its parameters are
        frozen; leaving it in training mode would still update its BatchNorm
        running statistics, so one training run would alter the features
        seen by the next.
        """
        super().train(mode)
        self.feature.eval()
        return self

    def forward(self, x):
        """Return class logits of shape (bs, num_classes) for an image batch."""
        x = self.feature(x)
        # (bs, C, H, W) -> (bs, C, H*W); spatial size is read from the tensor.
        x = x.flatten(2)
        num_positions = x.size(-1)
        # Batch matrix multiplication: (bs, C, H*W) x (bs, H*W, C) -> (bs, C, C),
        # averaged over the spatial positions.
        x = torch.bmm(x, torch.transpose(x, 1, 2)) / num_positions
        x = torch.sqrt(x + 1e-5)

        if self.attention:
            x = self.attn(x)
        else:
            # Without attention the (bs, C, C) matrix must still be reduced to
            # (bs, C); otherwise fc would emit (bs, C, num_classes) logits that
            # do not match (bs,) targets. The mean over dim 1 is the
            # uniform-weight counterpart of the attention pooling.
            x = x.mean(dim=1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

In [None]:
def train(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """
    Train a classification model and print per-epoch loss and accuracy.

    Args:
        model (nn.Module): The model to train; called as ``model(images)`` and
            expected to return logits of shape (batch, num_classes).
        dataloader (Iterable): Yields ``(images, labels, extra)`` triples; the
            third element is ignored. It does not need to support ``len()``.
        criterion (nn.Module): Loss function, called as ``criterion(outputs, labels)``.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        device (torch.device): Device to train on (CPU or GPU).
        num_epochs (int): Number of epochs to train for.

    Returns:
        None
    """
    model.train()  # Set the model to training mode

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        num_batches = 0  # counted here so the iterable needs no __len__

        for images, labels, _ in dataloader:
            # Move data to the specified device
            images, labels = images.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Accumulate running loss and accuracy
            running_loss += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

        if num_batches == 0 or total_predictions == 0:
            # Empty dataloader: there is nothing to average, so report it
            # instead of dividing by zero.
            print(f"Epoch [{epoch + 1}/{num_epochs}], no batches to train on.")
            continue

        # Print epoch statistics (loss is the mean of the per-batch losses)
        epoch_loss = running_loss / num_batches
        epoch_accuracy = correct_predictions / total_predictions * 100
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

    print("Training complete.")

In [None]:
def run_training(model):
    """Train ``model`` for 10 epochs on the module-level ``dataloader``.

    Uses cross-entropy loss and Adam (learning rate 0.001) over
    ``model.parameters()``, running on the module-level ``device``.
    """
    train(
        model,
        dataloader,
        nn.CrossEntropyLoss(),
        optim.Adam(model.parameters(), lr=0.001),
        device,
        num_epochs=10,
    )

In [None]:
# Experiment 1: model built with the default attention=False, moved to the
# selected device and trained with run_training.
cnn_model_without_attention = Cnn_Resnet()
cnn_model_without_attention.to(device)
run_training(cnn_model_without_attention)

In [None]:
# Experiment 2: same setup with attention=True, so the forward pass routes
# through the Attention module before the classifier.
cnn_model_with_attention = Cnn_Resnet(attention=True)
cnn_model_with_attention.to(device)
run_training(cnn_model_with_attention)