In [2]:
# Compatibility shim for libraries that still look up the abstract base classes
# on the top-level `collections` module (the failure this works around is
# "AttributeError: module 'collections' has no attribute 'MutableMapping'").
# Every public name of `collections.abc` that `collections` lacks is aliased
# onto it; names that already exist are left untouched.
import collections
import collections.abc

for abc_name in collections.abc.__all__:
    if hasattr(collections, abc_name):
        continue
    setattr(collections, abc_name, getattr(collections.abc, abc_name))

In [None]:
import os

import six
import torch
import torch.nn.functional as F
import torchvision
from catalyst import utils
from catalyst.contrib.datasets import MNIST
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# Experiment knob: used as the global random seed here and to derive the number
# of training epochs (N * 5) in the training cell.
N = 2

# --- Compatibility shims -----------------------------------------------------
# Give `six` a `string_classes` attribute, falling back to `(str,)` when
# `six.string_types` is not available.
if not hasattr(six, 'string_classes'):
    six.string_classes = getattr(six, 'string_types', (str,))

# Make sure the attribute lookup `torch._six.string_classes` resolves: either
# expose `six` under that name or add the missing attribute to the existing one.
# NOTE(review): this runs after the catalyst imports above, exactly as in the
# original cell; if catalyst itself needs the shim at import time it would have
# to move before those imports -- confirm against the installed versions.
if not hasattr(torch, '_six'):
    torch._six = six
elif not hasattr(torch._six, 'string_classes'):
    torch._six.string_classes = six.string_classes

# --- Data --------------------------------------------------------------------
# Seed once, right before anything that may consume random numbers.
utils.set_global_seed(N)

# The training split is downloaded into the current working directory; the
# validation split is then read from the same root without `download=True`.
train_dataset = MNIST(root=os.getcwd(), train=True, download=True)
val_dataset = MNIST(root=os.getcwd(), train=False)

# NOTE(review): the training loader iterates in dataset order (no shuffle=True);
# confirm this is intentional.
train_dataloader = DataLoader(train_dataset, batch_size=128)
val_dataloader = DataLoader(val_dataset, batch_size=128)

class Identical(nn.Module):
    """Pass-through module: ``forward`` returns its argument unchanged."""

    def forward(self, x):
        return x


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    The first dimension is kept as-is (it is read with ``x.size(0)``), so an
    input of shape ``(B, d1, d2, ...)`` comes out as ``(B, d1 * d2 * ...)``.
    """

    def forward(self, x):
        # `reshape` instead of `view`: it gives the same result for contiguous
        # tensors but also accepts non-contiguous ones (e.g. after a
        # transpose/permute), where `view` raises a RuntimeError.
        return x.reshape(x.size(0), -1)


# Alias for the activation class; currently the no-op `Identical` module.
activation = Identical

class EnhancedMNISTModel(nn.Module):
    """Fully connected classifier: 28*28 -> 512 -> 256 -> 128 -> 10.

    The input is flattened, then passed through three hidden stages, each
    being Linear -> BatchNorm1d -> ReLU -> Dropout. The last Linear layer
    returns the raw 10-way scores (no softmax is applied here).

    Args:
        dropout_rate: probability handed to each of the three ``nn.Dropout``
            layers.
    """

    def __init__(self, dropout_rate=0.2):
        super().__init__()
        self.flatten = Flatten()

        # Hidden stage 1: 784 -> 512
        self.fc1 = nn.Linear(28 * 28, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Hidden stage 2: 512 -> 256
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Hidden stage 3: 256 -> 128
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout3 = nn.Dropout(dropout_rate)

        # Output layer: 128 -> 10 scores
        self.fc4 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the 10 output scores for every sample in the batch ``x``."""
        hidden = self.flatten(x)

        hidden_stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, norm, drop in hidden_stages:
            hidden = drop(F.relu(norm(linear(hidden))))

        return self.fc4(hidden)

# Instantiate the network with its default dropout rate (0.2); this `model`
# object is what the optimizer and training loop in the next cell operate on.
model = EnhancedMNISTModel()




















In [None]:
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
# mode 'max': the monitored quantity is validation accuracy (higher is better).
# `verbose` is deliberately not passed -- it is deprecated on LR schedulers in
# recent PyTorch releases; learning-rate changes are printed by the loop below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', patience=3, factor=0.5
)

loaders = {"train": train_dataloader, "valid": val_dataloader}

max_epochs = N * 5

# Per-epoch accuracy history for each loader; read by the plotting cell.
accuracy = {"train": [], "valid": []}


def _standardize(images):
    """Cast a batch to float and standardize it with its own statistics.

    The mean and std are scalars computed over the whole batch; 1e-8 keeps the
    division finite when the std is zero.
    """
    images = images.float()
    return (images - images.mean()) / (images.std() + 1e-8)


def _run_phase(phase, dataloader):
    """Run one full pass over ``dataloader`` and return its accuracy.

    Weights are updated only when ``phase == "train"``; any other phase runs
    the model in eval mode with gradient tracking disabled.
    """
    is_train = phase == "train"
    # Set train/eval mode once per pass rather than once per batch.
    model.train(is_train)

    n_correct = 0
    n_seen = 0
    for x_batch, y_batch in dataloader:
        if is_train:
            optimizer.zero_grad()

        # unsqueeze(1) inserts a size-1 dimension after the batch dimension.
        # NOTE(review): assumes batches arrive as (batch, 28, 28) -- confirm.
        with torch.set_grad_enabled(is_train):
            outp = model(_standardize(x_batch).unsqueeze(1))

        if is_train:
            loss = criterion(outp, y_batch)
            loss.backward()
            optimizer.step()

        n_correct += (outp.argmax(-1) == y_batch).sum().item()
        n_seen += len(y_batch)

    return n_correct / n_seen


for epoch in range(max_epochs):
    for k, dataloader in loaders.items():
        phase_accuracy = _run_phase(k, dataloader)
        print(f"Epoch: {epoch+1}")
        print(f"Loader: {k}. Accuracy: {phase_accuracy}")
        accuracy[k].append(phase_accuracy)

    # Feed the latest validation accuracy to the scheduler. Without this call
    # ReduceLROnPlateau never adjusts the learning rate.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(accuracy["valid"][-1])
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Learning rate reduced: {lr_before} -> {lr_after}")



In [None]:
import matplotlib.pyplot as plt

%matplotlib inline

epochs = list(range(1, max_epochs + 1))

plt.figure(figsize=(10, 6))
plt.plot(epochs, accuracy['train'], 'b-', label='Training Accuracy')
plt.plot(epochs, accuracy['valid'], 'r-', label='Validation Accuracy')
plt.title('Training and Validation Accuracy over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.grid(True)
plt.legend()



