In [1]:
import sys
sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

image_size =224
tiny_transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((image_size, image_size)), 
        transforms.RandomCrop(image_size, padding=5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
tiny_transform_val = transforms.Compose([
        transforms.Resize((image_size, image_size)), 
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
tiny_transform_test = transforms.Compose([
        transforms.Resize((image_size, image_size)), 
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
                                                    data_dir = '../datasets',
                                                    transform_train=tiny_transform_train,
                                                    transform_val=tiny_transform_val,
                                                    transform_test=tiny_transform_test,
                                                    batch_size=64,
                                                    image_size=image_size)


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
from einops import rearrange
from einops.layers.torch import Rearrange
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder

# Modified Swin Transformer implementation
class AdaptivePatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.scale_factor = nn.Parameter(torch.ones(1) * 1.0)
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                             kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        scaled_patch = self.patch_size * torch.sigmoid(self.scale_factor)
        B, C, H, W = x.shape
        scaled_h = int(H * self.patch_size / scaled_patch)
        scaled_w = int(W * self.patch_size / scaled_patch)
        x = F.interpolate(x, size=(scaled_h, scaled_w), mode='bilinear')
        x = self.proj(x)
        return x

class MultiScaleAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.num_scales = 3
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.scale_pos_embed = nn.Parameter(torch.zeros(1, self.num_scales, dim // self.num_scales))
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        scale_q = q.chunk(self.num_scales, dim=1)
        scale_k = k.chunk(self.num_scales, dim=1)
        scale_v = v.chunk(self.num_scales, dim=1)
        
        scale_outputs = []
        for i in range(self.num_scales):
            sq, sk, sv = scale_q[i], scale_k[i], scale_v[i]
            scale_pos = self.scale_pos_embed[:, i:i+1].expand(B, -1, -1)
            sq = sq + scale_pos.reshape(B, -1, C // self.num_heads)
            
            attn = (sq @ sk.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            
            scale_output = (attn @ sv)
            scale_outputs.append(scale_output)
        
        x = torch.cat(scale_outputs, dim=1)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    class ModifiedSwinTransformer(nn.Module):
        def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=200,
                    embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                    window_size=7, mlp_ratio=4., qkv_bias=True,
                    drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1):
            super().__init__()
            
            self.patch_embed = AdaptivePatchEmbedding(
                img_size=img_size, patch_size=patch_size,
                in_chans=in_chans, embed_dim=embed_dim)
            
            # Calculate correct number of patches
            self.img_size = img_size
            self.patch_size = patch_size
            self.num_patches = (img_size // patch_size) ** 2
            
            # Initialize position embedding with correct size
            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
            
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
            
            self.stages = nn.ModuleList()
            current_num_patches = self.num_patches
        
        for i in range(len(depths)):
            stage = nn.Sequential(*[
                ModifiedSwinTransformerBlock(
                    dim=embed_dim * (2**i),
                    num_heads=num_heads[i],
                    window_size=window_size,
                    shift_size=0 if (j % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:i])+j]
                )
                for j in range(depths[i])
            ])
            self.stages.append(stage)
            
            if i < len(depths) - 1:
                # Update current number of patches for next stage
                next_num_patches = current_num_patches // 4
                self.stages.append(
                    nn.Sequential(
                        Rearrange('b (h w) c -> b c h w', h=int(np.sqrt(current_num_patches))),
                        nn.Conv2d(embed_dim * (2**i), embed_dim * (2**(i+1)), kernel_size=2, stride=2),
                        Rearrange('b c h w -> b (h w) c')
                    )
                )
                current_num_patches = next_num_patches
        
        self.norm = nn.LayerNorm(embed_dim * (2**(len(depths)-1)))
        self.head = nn.Linear(embed_dim * (2**(len(depths)-1)), num_classes)
        
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            
    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        
        # Verify shapes before adding position embedding
        if x.size(1) != self.pos_embed.size(1):
            raise ValueError(f"Position embedding size mismatch. Input sequence length: {x.size(1)}, "
                           f"Position embedding size: {self.pos_embed.size(1)}. "
                           f"Expected sequence length: {self.num_patches}")
        
        x = x + self.pos_embed
        
        for stage in self.stages:
            x = stage(x)
        
        x = self.norm(x).mean(dim=1)
        x = self.head(x)
        return x

# ... (rest of the training code remains the same)

def main():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    
    # Training settings
    batch_size = 32
    num_epochs = 100
    learning_rate = 1e-3
    weight_decay = 0.05
    
    # Print model configuration
    img_size = 224
    patch_size = 4
    num_patches = (img_size // patch_size) ** 2
    print(f"Input image size: {img_size}x{img_size}")
    print(f"Patch size: {patch_size}x{patch_size}")
    print(f"Number of patches: {num_patches}")
    
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # ... (rest of the main function remains the same)

if __name__ == '__main__':
    main()


Epoch 1/100


Training:   0%|          | 0/1563 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (5776) must match the size of tensor b (3136) at non-singleton dimension 1