# Pytorch version of the implementation


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
from torchvision.models import efficientnet_v2_s
from torch.optim import lr_scheduler

# Define the transformation pipeline
transform = transforms.Compose([
    transforms.Resize(224), # resize the image to 224x224
    transforms.Grayscale(num_output_channels=3), # convert the image to RGB format
    transforms.ToTensor(), # convert the image to a PyTorch tensor
])

# Create the FashionMNIST datasets with the transformation applied
train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)

idx = (train_dataset.targets == 3) | (train_dataset.targets == 4)
train_dataset.data = train_dataset.data[idx]
train_dataset.targets = train_dataset.targets[idx]


idx = (test_dataset.targets == 3) | (test_dataset.targets == 4)
test_dataset.data = test_dataset.data[idx]
test_dataset.targets = test_dataset.targets[idx]


N = 128
bs = 32
train_dataset = Subset(train_dataset, range(N))
test_dataset = Subset(test_dataset, range(N))


# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

# Define the EfficientNet_V2_S model
model = efficientnet_v2_s(pretrained=True)

# Define the loss function
criterion = nn.CrossEntropyLoss()
lr = 0.001
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)

scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Train the model
num_epochs = 10


In [6]:

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0.0


    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
  
        optimizer.step()
       
        # Update statistics
        train_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        train_correct += (predicted == target).sum().item()

    # Calculate statistics for the validation set
    model.eval()
    val_loss = 0.0
    val_correct = 0.0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            output = model(data)
            loss = criterion(output, target)
            
            # Update statistics
            val_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            val_correct += (predicted == target).sum().item()

    # Print the training and validation statistics for the epoch
    train_loss /= len(train_loader.dataset)
    train_acc = 100.0 * train_correct / len(train_loader.dataset)
    val_loss /= len(test_loader.dataset)
    val_acc = 100.0 * val_correct / len(test_loader.dataset)
    
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')


Epoch 1/10, Train Loss: 0.1633, Train Acc: 33.59%, Val Loss: 0.2387, Val Acc: 0.78%
Epoch 2/10, Train Loss: 0.0173, Train Acc: 96.09%, Val Loss: 0.0804, Val Acc: 77.34%
Epoch 3/10, Train Loss: 0.0026, Train Acc: 98.44%, Val Loss: 0.0316, Val Acc: 89.84%
Epoch 4/10, Train Loss: 0.0008, Train Acc: 99.22%, Val Loss: 0.0341, Val Acc: 85.16%
Epoch 5/10, Train Loss: 0.0006, Train Acc: 99.22%, Val Loss: 0.0438, Val Acc: 82.81%
Epoch 6/10, Train Loss: 0.0040, Train Acc: 98.44%, Val Loss: 0.0354, Val Acc: 85.94%
Epoch 7/10, Train Loss: 0.0022, Train Acc: 99.22%, Val Loss: 0.0201, Val Acc: 88.28%
Epoch 8/10, Train Loss: 0.0005, Train Acc: 99.22%, Val Loss: 0.0141, Val Acc: 89.84%
Epoch 9/10, Train Loss: 0.0003, Train Acc: 100.00%, Val Loss: 0.0110, Val Acc: 89.06%
Epoch 10/10, Train Loss: 0.0001, Train Acc: 100.00%, Val Loss: 0.0107, Val Acc: 90.62%
