In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

from tqdm import tqdm

import images and make dataloader

In [2]:
from datasets import load_dataset

ds = load_dataset("timm/resisc45")

Downloading readme:   0%|          | 0.00/3.11k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/85.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/85.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/18900 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/6300 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6300 [00:00<?, ? examples/s]

In [3]:
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'label', 'image_id'],
        num_rows: 18900
    })
    validation: Dataset({
        features: ['image', 'label', 'image_id'],
        num_rows: 6300
    })
    test: Dataset({
        features: ['image', 'label', 'image_id'],
        num_rows: 6300
    })
})

In [9]:
transform = transforms.Compose([
    transforms.ToTensor()
])

# Create a PyTorch Dataset
class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        label = self.dataset[idx]['label']
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Create datasets for training and validation (assuming split exists)
train_dataset = CustomImageDataset(ds['train'], transform=transform)
val_dataset = CustomImageDataset(ds['validation'], transform=transform)
test_dataset = CustomImageDataset(ds['test'], transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)


# Example: Iterate through the training data
for images, labels in train_loader:
    print(images.shape, labels.shape)
    break

torch.Size([64, 3, 256, 256]) torch.Size([64])


In [15]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size=256, patch_size=16, in_channels=3, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.position_embedding = nn.Parameter(torch.randn(1, self.grid_size**2 + 1, embed_dim))
    
    def forward(self, x):
        batch_size = x.size(0)
        x = self.embedding(x)  # (B, E, H', W')
        x = x.flatten(2)  # (B, E, N)
        x = x.transpose(1, 2)  # (B, N, E)
        cls_token = self.position_embedding[:, 0, :].unsqueeze(1).expand(batch_size, 1, -1)
        x = torch.cat((cls_token, x), dim=1)  # (B, N+1, E)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x):
        B, N, E = x.size()
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / self.head_dim**0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        out = attn_weights @ v
        out = out.transpose(1, 2).reshape(B, N, E)
        out = self.out(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, embed_dim=768, ff_dim=3*768):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, embed_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, ff_dim=768, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = FeedForward(embed_dim, ff_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x2 = self.attention(x)
        x = x + self.dropout(x2)
        x = self.ln1(x)
        x2 = self.ffn(x)
        x = x + self.dropout(x2)
        x = self.ln2(x)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, num_classes, image_size=256, patch_size=16, in_channels=3, embed_dim=768, num_heads=12, ff_dim=3*768, num_layers=6):
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim) for _ in range(num_layers)
        ])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
    
    def forward(self, x):
        x = self.patch_embed(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = x[:, 0]  # Take the class token
        x = self.mlp_head(x)
        return x
    
model = VisionTransformer(num_classes=1000)

In [16]:
criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(), lr=1e-4)

In [17]:
from prettytable import PrettyTable
def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
count_parameters(model)

+-------------------------------------------+------------+
|                  Modules                  | Parameters |
+-------------------------------------------+------------+
|       patch_embed.position_embedding      |   197376   |
|        patch_embed.embedding.weight       |   589824   |
|         patch_embed.embedding.bias        |    768     |
| transformer_blocks.0.attention.qkv.weight |  1769472   |
|  transformer_blocks.0.attention.qkv.bias  |    2304    |
| transformer_blocks.0.attention.out.weight |   589824   |
|  transformer_blocks.0.attention.out.bias  |    768     |
|    transformer_blocks.0.ffn.fc1.weight    |  1769472   |
|     transformer_blocks.0.ffn.fc1.bias     |    2304    |
|    transformer_blocks.0.ffn.fc2.weight    |  1769472   |
|     transformer_blocks.0.ffn.fc2.bias     |    768     |
|      transformer_blocks.0.ln1.weight      |    768     |
|       transformer_blocks.0.ln1.bias       |    768     |
|      transformer_blocks.0.ln2.weight      |    768    

37003240

In [18]:
device = 'cuda' if torch.cuda.is_available() else torch.device('cpu')
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
model = model.to(device)
model.to(device)

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    # Training loop
    for images, labels in tqdm(train_loader):

        images, labels = images.to(device,non_blocking=True), labels.to(device,non_blocking=True)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Calculate accuracy
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for images, labels in train_loader:
#             images, labels = images.to(device,non_blocking=True), labels.to(device,non_blocking=True)
#             outputs = model(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#         accuracy = 100 * correct / total
#         print(f'Epoch: {epoch+1}/{num_epochs}, Accuracy: {accuracy:.2f}%')


    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device,non_blocking=True), labels.to(device,non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Accumulate the validation loss
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Compute average validation loss
    average_val_loss = val_loss / len(val_loader)
    accuracy = 100 * correct / total

    print(f'Epoch: {epoch+1}/{num_epochs}, Accuracy: {accuracy:.2f}%, Loss: {average_val_loss:.4f}')

Let's use 2 GPUs!


100%|██████████| 296/296 [03:14<00:00,  1.52it/s]


Epoch: 1/10, Accuracy: 38.48%, Loss: 2.1417


100%|██████████| 296/296 [03:14<00:00,  1.52it/s]


Epoch: 2/10, Accuracy: 57.35%, Loss: 1.4431


100%|██████████| 296/296 [03:14<00:00,  1.52it/s]


Epoch: 3/10, Accuracy: 62.41%, Loss: 1.2521


100%|██████████| 296/296 [03:14<00:00,  1.52it/s]


Epoch: 4/10, Accuracy: 64.83%, Loss: 1.1681


100%|██████████| 296/296 [03:14<00:00,  1.52it/s]


Epoch: 5/10, Accuracy: 65.43%, Loss: 1.1507


100%|██████████| 296/296 [03:13<00:00,  1.53it/s]


Epoch: 6/10, Accuracy: 66.87%, Loss: 1.0781


100%|██████████| 296/296 [03:13<00:00,  1.53it/s]


Epoch: 7/10, Accuracy: 67.71%, Loss: 1.0528


100%|██████████| 296/296 [03:13<00:00,  1.53it/s]


Epoch: 8/10, Accuracy: 69.56%, Loss: 1.0106


100%|██████████| 296/296 [03:13<00:00,  1.53it/s]


Epoch: 9/10, Accuracy: 67.32%, Loss: 1.1239


100%|██████████| 296/296 [03:13<00:00,  1.53it/s]


Epoch: 10/10, Accuracy: 68.37%, Loss: 1.1214


For increasing accuracy 

1. number of epochs need to increased with change in learning rate
2. increases number of layers of transformer 
3. larger dataset would always be helpful that means augmentation is a lot helpful (as transformers are learning from 16x16 patches it learns diferently if you rotate some image by some degree)

In [None]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')