<a href="https://colab.research.google.com/github/KavyaS5757/AI-problems/blob/main/Exp%20-%20RIformer/Exp2_Rot_MNIST(MNIST2)_riformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
  import torch.nn as nn
  import torch

  # Define the RotateRelEbd class
  class RotateRelEbd(nn.Module):
      def __init__(self, dim, n_circle=2):
          super().__init__()
          diff_add_values = [-1, 0, 1, 2]
          for i in range(n_circle):  # circle: from outside to inside
              circle_mat = nn.Parameter(torch.zeros(1, dim)+diff_add_values[i])
              setattr(self, f"circle_mat{i}", circle_mat)

      def forward(self, x):
          H, W = x.shape[2], x.shape[3]
          out = torch.zeros_like(x, dtype=torch.float32)
          for i in range(H):
              for j in range(W):
                  dis_2_edge = min(i, j, H-i-1, W-j-1)
                  cir = getattr(self, f"circle_mat{dis_2_edge}")
                  out[:, :, i, j] = x[:, :, i, j].clone() + cir
          return out

  # Define the group_1x1_conv function
  def group_1x1_conv(n_group, in_chan, out_chan):
      return nn.Conv2d(
          in_channels=in_chan,
          out_channels=out_chan,
          kernel_size=1,
          stride=1,
          padding=0,
          bias=False,
          groups=n_group
      )

  # Define the Group_Pixel_Embed class
  class Group_Pixel_Embed(nn.Module):
      def __init__(self, n_group, in_chan, out_chan, spatial_reduce):
          super().__init__()
          self.g1conv = group_1x1_conv(n_group, in_chan, out_chan)
          self.pool = nn.AvgPool2d(kernel_size=spatial_reduce, stride=spatial_reduce)
          self.bn = nn.BatchNorm2d(num_features=out_chan)
          self.ac = nn.LeakyReLU(inplace=True)

      def forward(self, x):
          x = self.g1conv(x)
          x = self.pool(x)
          x = self.ac(self.bn(x))
          return x

  # Define the Mlp class
  class Mlp(nn.Module):
      def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
          super().__init__()
          out_features = out_features or in_features
          hidden_features = hidden_features or in_features
          self.fc1 = nn.Linear(in_features, hidden_features)
          self.act = nn.GELU()
          self.fc2 = nn.Linear(hidden_features, out_features)
          self.drop = nn.Dropout(drop)

      def forward(self, x):
          x = self.fc1(x)
          x = self.act(x)
          x = self.drop(x)
          x = self.fc2(x)
          x = self.drop(x)
          return x

  # Define the Attention class
  class Attention(nn.Module):
      def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
          super().__init__()
          self.dim = dim
          self.num_heads = num_heads
          self.q = nn.Linear(dim, dim, bias=True)
          self.kv = nn.Linear(dim, dim * 2, bias=True)
          self.attn_drop = nn.Dropout(attn_drop)
          self.proj = nn.Linear(dim, dim)
          self.proj_drop = nn.Dropout(proj_drop)

      def forward(self, x):
          B, N, C = x.shape
          q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
          kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
          k, v = kv[0], kv[1]

          attn = (q @ k.transpose(-2, -1))
          attn = attn.softmax(dim=-1)
          attn = self.attn_drop(attn)

          x = (attn @ v).transpose(1, 2).reshape(B, N, C)
          x = self.proj(x)
          x = self.proj_drop(x)

          return x

  # Define the Block class
  class Block(nn.Module):
      def __init__(self, dim, num_heads=8, mlp_ratio=1, drop=0., attn_drop=0.):
          super().__init__()
          self.norm1 = nn.LayerNorm(dim)
          self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
          self.norm2 = nn.LayerNorm(dim)
          mlp_hidden_dim = int(dim * mlp_ratio)
          self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

      def forward(self, x):
          x = x + self.attn(self.norm1(x))
          x = x + self.mlp(self.norm2(x))
          return x

  # Define the RIFormer class
  class RIFormer(nn.Module):
      def __init__(self, emb_type=None, C1=32, C2=64, C3=128, C4=256, mlp_ratio=2, n_class=10):
          super().__init__()
          self.emb_type = emb_type
          self.gpe_1 = Group_Pixel_Embed(n_group=1, in_chan=3, out_chan=C1, spatial_reduce=4)  # in_chan changed to 1
          # rot_ebd or pos_ebd
          self.rot_ebd_1 = RotateRelEbd(C1, n_circle=4)  # 4,2,1,1
          self.pos_ebd_1 = nn.Parameter(torch.zeros(1, 8*8, C1))  # 8 * 8 = fmap size
          self.block_1_1 = Block(dim=C1, num_heads=1, mlp_ratio=mlp_ratio)
          self.block_1_2 = Block(dim=C1, num_heads=1, mlp_ratio=mlp_ratio)
          self.many_1 = nn.Sequential(self.block_1_1, self.block_1_2)

          self.gpe_2 = Group_Pixel_Embed(n_group=1, in_chan=C1, out_chan=C2, spatial_reduce=2)
          self.rot_ebd_2 = RotateRelEbd(C2, n_circle=2)  # 4,2,1,1
          self.pos_ebd_2 = nn.Parameter(torch.zeros(1, 4*4, C2))
          self.block_2_1 = Block(dim=C2, num_heads=2, mlp_ratio=mlp_ratio)
          self.block_2_2 = Block(dim=C2, num_heads=2, mlp_ratio=mlp_ratio)
          self.many_2 = nn.Sequential(self.block_2_1, self.block_2_2)

          self.gpe_3 = Group_Pixel_Embed(n_group=1, in_chan=C2, out_chan=C3, spatial_reduce=2)
          self.rot_ebd_3 = RotateRelEbd(C3, n_circle=1)  # 4,2,1,1
          self.pos_ebd_3 = nn.Parameter(torch.zeros(1, 2*2, C3))
          self.block_3_1 = Block(dim=C3, num_heads=4, mlp_ratio=mlp_ratio)
          self.block_3_2 = Block(dim=C3, num_heads=4, mlp_ratio=mlp_ratio)
          self.many_3 = nn.Sequential(self.block_3_1, self.block_3_2)

          self.gpe_4 = Group_Pixel_Embed(n_group=1, in_chan=C3, out_chan=C4, spatial_reduce=2)
          self.rot_ebd_4 = RotateRelEbd(C4, n_circle=1)  # 4,2,1,1
          self.pos_ebd_4 = nn.Parameter(torch.zeros(1, 1*1, C4))
          self.block_4_1 = Block(dim=C4, num_heads=8, mlp_ratio=mlp_ratio)
          self.block_4_2 = Block(dim=C4, num_heads=8, mlp_ratio=mlp_ratio)
          self.many_4 = nn.Sequential(self.block_4_1, self.block_4_2)

          self.cls_layer = nn.Linear(C4, n_class)

      def forward(self, x):
            B, C, H, W = x.shape[0], x.shape[1], x.shape[2], x.shape[3]

            x = self.gpe_1(x)
            if self.emb_type == 'rot':
                x = self.rot_ebd_1(x)
            x = x.flatten(2).transpose(1, 2)
            if self.emb_type == 'pos':
                x = x + self.pos_ebd_1
            x = self.many_1(x)
            x = x.reshape(B, H//4, W//4, -1).permute(0, 3, 1, 2).contiguous()

            x = self.gpe_2(x)
            if self.emb_type == 'rot':
                x = self.rot_ebd_2(x)
            x = x.flatten(2).transpose(1, 2)
            if self.emb_type == 'pos':
                x = x + self.pos_ebd_2
            x = self.many_2(x)
            x = x.reshape(B, H//8, W//8, -1).permute(0, 3, 1, 2).contiguous()

            x = self.gpe_3(x)
            if self.emb_type == 'rot':
                x = self.rot_ebd_3(x)
            x = x.flatten(2).transpose(1, 2)
            if self.emb_type == 'pos':
                x = x + self.pos_ebd_3
            x = self.many_3(x)
            x = x.reshape(B, H//16, W//16, -1).permute(0, 3, 1, 2).contiguous()

            x = self.gpe_4(x)
            if self.emb_type == 'rot':
                x = self.rot_ebd_4(x)
            x = x.flatten(2).transpose(1, 2)
            if self.emb_type == 'pos':
                x = x + self.pos_ebd_4
            x = self.many_4(x)
            x = x.reshape(B, H//32, W//32, -1).permute(0, 3, 1, 2).contiguous()
            x = x.flatten(start_dim=1)

            x = self.cls_layer(x)

            return x

  if __name__ == "__main__":
      input_t = torch.rand(8, 3, 32, 32)
      net = RIFormer(emb_type='rot', n_class=10)
      print(net(input_t).shape)

torch.Size([8, 10])


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets

In [3]:
# Define transformation for the MNIST dataset with random rotations
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomChoice([
        transforms.RandomRotation(degrees=(0, 0)),
        transforms.RandomRotation(degrees=(90, 90)),
        transforms.RandomRotation(degrees=(180, 180)),
        transforms.RandomRotation(degrees=(270, 270)),
    ]),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Repeat the grayscale channel 3 times
])


In [4]:
# Load the MNIST dataset with the specified transformation
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=train_transform)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5222779.45it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 153579.52it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1445201.41it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2399613.15it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
# Split the full dataset into 80% for training and 20% for testing
train_size = int(0.8 * len(train_dataset))
test_size = len(train_dataset) - train_size
train_dataset, _ = random_split(train_dataset, [train_size, test_size])

In [6]:
# Further split the training dataset into 85% for training and 15% for validation
val_size = int(0.15 * len(train_dataset))
train_size = len(train_dataset) - val_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])


In [7]:
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [8]:
# Define the model, loss function, and optimizer
net = RIFormer(emb_type='rot', n_class=10)

In [9]:
# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RIFormer(emb_type='rot', n_class=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)




In [10]:
# Function to validate the model
def validate_model(model, val_loader, criterion):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    return val_loss / len(val_loader)

In [11]:
# Function to train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=30):
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validate the model
        val_loss = validate_model(model, val_loader, criterion)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {running_loss/len(train_loader):.4f}, Validation Loss: {val_loss:.4f}")

        # Step the scheduler
        scheduler.step(val_loss)

        # Save the model if validation loss has improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Model saved at epoch {epoch+1} with validation loss {best_val_loss:.4f}")

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))


In [12]:
# Function to test the model
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Test Accuracy: {:.4f}".format(correct / total))

In [13]:
# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=30)


Epoch [1/30], Training Loss: 1.0107, Validation Loss: 1.3082
Model saved at epoch 1 with validation loss 1.3082
Epoch [2/30], Training Loss: 0.6748, Validation Loss: 0.9407
Model saved at epoch 2 with validation loss 0.9407
Epoch [3/30], Training Loss: 0.5379, Validation Loss: 0.8886
Model saved at epoch 3 with validation loss 0.8886
Epoch [4/30], Training Loss: 0.4749, Validation Loss: 0.7932
Model saved at epoch 4 with validation loss 0.7932
Epoch [5/30], Training Loss: 0.4206, Validation Loss: 0.8327
Epoch [6/30], Training Loss: 0.3841, Validation Loss: 1.0510
Epoch [7/30], Training Loss: 0.3490, Validation Loss: 0.5105
Model saved at epoch 7 with validation loss 0.5105
Epoch [8/30], Training Loss: 0.3292, Validation Loss: 0.3853
Model saved at epoch 8 with validation loss 0.3853
Epoch [9/30], Training Loss: 0.3160, Validation Loss: 0.6834
Epoch [10/30], Training Loss: 0.3005, Validation Loss: 1.3532
Epoch [11/30], Training Loss: 0.2809, Validation Loss: 0.5253
Epoch [12/30], Traini

In [14]:
# Test the model on the same rotation
print("Testing on rotated MNIST dataset")
test_model(model, test_loader)

Testing on rotated MNIST dataset
Test Accuracy: 0.9207
