In [1]:
import numpy as np
import random
import torch
import torch.nn as nn 
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, datasets
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import matplotlib.pyplot as plt
from torchsummary import summary 
from torchmetrics import Accuracy, F1Score
import os

In [2]:
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),          
    transforms.RandomHorizontalFlip(),      
    transforms.RandomRotation(15),          
    transforms.ToTensor()                   
])

test_val_transform = transforms.Compose([
    transforms.Resize((224, 224)),          
    transforms.ToTensor()                  
])


class FFTShiftedMNIST(Dataset):
    def __init__(self,  base_transform, train_flag, fft_crop_size =None):
        self.dataset = datasets.CIFAR100(root='./data', train=train_flag, download=True, transform=base_transform)
        self.fft_crop_size = fft_crop_size
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        image = image.to("cuda")
        
        # Apply FFT, shift, and get magnitude and phase
        fft = torch.fft.fft2(image)  # Remove single channel dimension
        fft_shifted = torch.fft.fftshift(fft)
        real = torch.real(fft_shifted)
        imag = torch.imag(fft_shifted)     
        magnitude = torch.abs(fft_shifted)
        phase = torch.angle(fft_shifted)
        
        if self.fft_crop_size is not None:
            # Calculate the center crop region
            center_x, center_y = magnitude.shape[1] // 2, magnitude.shape[2] // 2
            crop_size = self.fft_crop_size // 2
            
            # Crop the magnitude and phase around the center
            imag = imag[:, center_x - crop_size:center_x + crop_size, center_y - crop_size:center_y + crop_size]
            real = real[:, center_x - crop_size:center_x + crop_size, center_y - crop_size:center_y + crop_size]
            magnitude = magnitude[:, center_x - crop_size:center_x + crop_size, center_y - crop_size:center_y + crop_size]
            phase = phase[:, center_x - crop_size:center_x + crop_size, center_y - crop_size:center_y + crop_size]

        
        # Stack magnitude and phase along the channel dimension
        transformed_image = torch.cat((magnitude, phase, real, imag), dim=0) 
        
        return transformed_image, label

# Create dataset instances for train, validation, and test sets
train_dataset = FFTShiftedMNIST(train_transform,train_flag=True, fft_crop_size = 112)
test_dataset = FFTShiftedMNIST(test_val_transform,train_flag=False, fft_crop_size = 112)

# Split the training dataset into train and validation sets
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Image size: {next(iter(train_dataset.dataset))[0].shape}")
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Files already downloaded and verified
Files already downloaded and verified
Image size: torch.Size([12, 112, 112])
Train size: 40000
Validation size: 10000
Test size: 10000


In [3]:
class BasicNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(12, 32, kernel_size=3, padding=1),
            nn.ELU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1), 
            nn.ELU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(64,128 , kernel_size=3, padding=1),
            nn.ELU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ELU(),
            nn.BatchNorm2d(256),
            
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Dropout(0.2)
        ) 
        self.classifier = nn.Linear(256, num_classes) 
    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor to shape (batch_size, 256)
        x = self.classifier(x)
        return x

In [4]:
freq_model = BasicNet(100)
device = torch.device("cuda")
freq_model.to(device)
summary(freq_model, (12, 112, 112))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]           3,488
               ELU-2         [-1, 32, 112, 112]               0
       BatchNorm2d-3         [-1, 32, 112, 112]              64
         MaxPool2d-4           [-1, 32, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          18,496
               ELU-6           [-1, 64, 56, 56]               0
       BatchNorm2d-7           [-1, 64, 56, 56]             128
         MaxPool2d-8           [-1, 64, 28, 28]               0
            Conv2d-9          [-1, 128, 28, 28]          73,856
              ELU-10          [-1, 128, 28, 28]               0
      BatchNorm2d-11          [-1, 128, 28, 28]             256
        MaxPool2d-12          [-1, 128, 14, 14]               0
           Conv2d-13          [-1, 256, 14, 14]         295,168
              ELU-14          [-1, 256,

## Training & Validation 

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(freq_model.parameters(), lr = 0.005) 
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [6]:
save_dir = 'baseline_cnn_freq'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

EPOCHS = 200
BATCH_SIZE = 64
train_cost, val_cost = [],[]
train_acc, val_acc = [],[]
early_stopping_patience = 10
best_val_loss = float('inf')
early_stop_counter = 0

for epoch in range(EPOCHS):
    train_loss = 0 
    acc_train = Accuracy('multiclass', num_classes=100).to(device)
    freq_model.train().cuda()  # set the model for training 

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        pred = freq_model(images)
        curr_loss = criterion(pred, labels)
        train_loss += curr_loss.item()

        curr_loss.backward()
        optimizer.step()

        acc_train(pred, labels)
        
    train_cost.append(train_loss / len(train_loader))
    train_acc.append(acc_train.compute()) 


    val_loss = 0 
    acc_val = Accuracy(task="multiclass", num_classes=100).to(device)
    freq_model.eval().cuda()

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            pred = freq_model(images)

            curr_loss = criterion(pred, labels)
            val_loss += curr_loss.item()

            _, predicted = torch.max(pred, 1)
            acc_val(predicted, labels)

    val_cost.append(val_loss / len(val_loader))
    val_acc.append(acc_val.compute())

    print(f"[{epoch+1}\{EPOCHS}], train-loss: {train_cost[-1]}, train-acc: {acc_train.compute()}, val-loss: {val_cost[-1]}, val-acc: {acc_val.compute()}")
#     torch.save(freq_model.state_dict(), f'baseline_cnn_freq/checkpoint_{epoch + 1}')

    scheduler.step(val_cost[-1])

    if val_cost[-1] < best_val_loss:
        best_val_loss = val_cost[-1]
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break
    
torch.save(freq_model.state_dict(), 'baseline_cnn_freq/freq_model_200_EPOCHS.pth') 

[1\200], train-loss: 3.8319502937316896, train-acc: 0.10300000011920929, val-loss: 5.315420794638858, val-acc: 0.05420000106096268
[2\200], train-loss: 3.2116633590698243, train-acc: 0.20607499778270721, val-loss: 5.142668377821613, val-acc: 0.07649999856948853
[3\200], train-loss: 2.9757780509948732, train-acc: 0.2522749900817871, val-loss: 6.153409222888339, val-acc: 0.061799999326467514
[4\200], train-loss: 2.8316635913848875, train-acc: 0.27992498874664307, val-loss: 3.8336095081013477, val-acc: 0.1671999990940094
[5\200], train-loss: 2.7222075000762938, train-acc: 0.30172500014305115, val-loss: 3.8707129256740496, val-acc: 0.19589999318122864
[6\200], train-loss: 2.629307083129883, train-acc: 0.32374998927116394, val-loss: 4.7485187387770145, val-acc: 0.13580000400543213
[7\200], train-loss: 2.5593336771011352, train-acc: 0.3386000096797943, val-loss: 2.970239580057229, val-acc: 0.27880001068115234
[8\200], train-loss: 2.4978007675170897, train-acc: 0.34942498803138733, val-loss: 

## Testing

In [5]:
final_freq_model = BasicNet(100)
final_freq_model.load_state_dict(torch.load("freq_model_200_EPOCHS.pth"))

<All keys matched successfully>

In [6]:
acc = Accuracy(task="multiclass", num_classes=100).to(device)
f1 = F1Score('multiclass', num_classes=100).to(device)

final_freq_model.eval().cuda()

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        predictions = final_freq_model(images) 

        _, predicted = torch.max(predictions, 1)

        acc(predicted, labels)
        f1(predicted, labels)


print(f"Test Accuracy: {acc.compute().data:.3f}")
print(f"F1 : {f1.compute().data:.3f}")

Test Accuracy: 0.458
F1 : 0.458
