In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim
import torchvision

In [2]:
class ChannelAttentionModule(nn.Module):
    """CBAM-style channel attention for 2-D feature maps.

    Global average- and max-pooled descriptors of ``x`` are each scored by a
    shared bottleneck MLP built from two 1x1 convolutions. The two scores are
    summed and squashed with a sigmoid, and the result rescales every channel
    of ``x``.

    Args:
        channels: number of input (and output) channels.
        reduction_ratio: bottleneck factor. The hidden width is
            ``channels // reduction_ratio``, clamped to at least 1.
    """

    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        hidden = max(1, channels // reduction_ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP. Submodule indices 0/1/2 keep the state_dict keys stable.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Score the avg-pooled descriptor first, then the max-pooled one, with the same MLP.
        avg_logits, max_logits = (self.fc(pool(x)) for pool in (self.avg_pool, self.max_pool))
        gate = self.sigmoid(avg_logits + max_logits)
        return gate * x

In [3]:
class SpatialAttentionModule(nn.Module):
    """CBAM-style spatial attention for 2-D feature maps (``Conv2d`` input).

    ``x`` is pooled across the channel axis (mean and max). The resulting
    2-channel map is convolved down to one channel and passed through a
    sigmoid. The output is ``x`` scaled by that per-position weight.

    Args:
        kernel_size: positive odd size of the square convolution applied to
            the pooled map. Defaults to 7, the value previously hard-coded.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        # padding = k // 2 keeps the attention map at the input's H x W for odd k.
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise statistics, each (B, 1, H, W) for a (B, C, H, W) input.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)

        combined = torch.cat([avg_out, max_out], dim=1)
        attention_weights = self.sigmoid(self.conv(combined))

        return attention_weights * x

In [4]:
class TemporalAttentionModule(nn.Module):
    """Squeeze-and-excitation style gate for 5-D ``(B, C, T, H, W)`` clips.

    Despite the name, the gate is per *channel*. Average and max statistics
    are pooled jointly over T, H and W, so every time step and position of a
    channel receives the same weight. Both pooled vectors go through a shared
    two-layer MLP; the summed scores are squashed with a sigmoid.

    Args:
        channels: number of channels C.
        reduction_ratio: bottleneck factor. The hidden width is
            ``channels // reduction_ratio``, clamped to at least 1.
    """

    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        hidden = max(1, channels // reduction_ratio)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        # Shared MLP. Indices 0/1/2 keep the state_dict keys stable.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Five-way unpacking rejects anything that is not a 5-D clip (ValueError).
        batch, channels, _, _, _ = x.shape

        def score(pool):
            # Pool to (B, C, 1, 1, 1), flatten to (B, C), then run the shared MLP.
            return self.fc(pool(x).view(batch, channels))

        gate = self.sigmoid(score(self.avg_pool) + score(self.max_pool))
        return x * gate.view(batch, channels, 1, 1, 1)

In [5]:
class ResNetBlock3D(nn.Module):
    """Residual block for clips shaped ``(B, C, T, H, W)``.

    Two ``(1, 3, 3)`` Conv3d + BatchNorm3d layers filter each frame with a 2-D
    kernel; there is no mixing across T. The result is refined by channel,
    spatial and temporal attention and added to a shortcut. The shortcut is
    projected with a 1x1x1 conv when the stride or channel count changes.

    ``ChannelAttentionModule`` and ``SpatialAttentionModule`` are built from 2-D
    pooling/convolution layers, which reject 5-D tensors. They are therefore
    applied frame by frame, with T folded into the batch axis.
    ``TemporalAttentionModule`` consumes the 5-D clip directly.

    Args:
        in_channels: channels of the input clip.
        out_channels: channels of the output clip.
        stride: spatial (H, W) stride of the first conv and of the shortcut.
            T is never strided.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock3D, self).__init__()

        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 3, 3),
                               stride=(1, stride, stride), padding=(0, 1, 1), bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(1, 3, 3),
                               padding=(0, 1, 1), bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)

        self.channel_att = ChannelAttentionModule(out_channels)
        self.spatial_att = SpatialAttentionModule()
        self.temporal_att = TemporalAttentionModule(out_channels)

        # Identity shortcut unless the output shape differs from the input shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1,
                          stride=(1, stride, stride), bias=False),
                nn.BatchNorm3d(out_channels)
            )

    @staticmethod
    def _framewise(module, x):
        """Apply a 4-D ``(N, C, H, W)`` module independently to every frame of a 5-D clip."""
        b, c, t, h, w = x.shape
        frames = x.transpose(1, 2).reshape(b * t, c, h, w)
        frames = module(frames)
        return frames.reshape(b, t, *frames.shape[1:]).transpose(1, 2).contiguous()

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The 2-D attention modules cannot take 5-D input, so run them per frame.
        out = self._framewise(self.channel_att, out)
        out = self._framewise(self.spatial_att, out)
        out = self.temporal_att(out)

        # Out-of-place add: avoids mutating a tensor produced by the attention ops.
        out = out + self.shortcut(residual)
        return self.relu(out)

In [6]:
class ResNetSequentialMNIST(nn.Module):
    """Attention ResNet classifier for single-channel image clips.

    Expects ``(B, 1, T, H, W)`` input. A plain image batch ``(B, 1, H, W)`` is
    also accepted and treated as a one-frame clip. Features from three
    residual stages (64 -> 128 -> 256 channels, H/W downsampled twice) are
    averaged over T and then over H, W before a linear classifier.

    Args:
        num_classes: number of output logits.
        num_frames: accepted for API compatibility. It is not used by this
            class; T is taken from the input.
    """

    def __init__(self, num_classes=10, num_frames=5):
        super(ResNetSequentialMNIST, self).__init__()
        self.prep = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True)
        )

        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)

        # Collapse T to 1 while keeping H, W; then global-average H, W.
        self.temp_pool = nn.AdaptiveAvgPool3d((1, None, None))
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        """Stack ``num_blocks`` ResNetBlock3D; only the first one changes channels/stride."""
        layers = [ResNetBlock3D(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(ResNetBlock3D(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        if x.dim() == 4:
            # (B, C, H, W) image batch -> (B, C, 1, H, W) single-frame clip.
            x = x.unsqueeze(2)

        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.temp_pool(x).squeeze(2)   # (B, 256, H', W')
        x = self.avg_pool(x)               # (B, 256, 1, 1)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [18]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=4, log_interval=100):
    """Train ``model`` in place and print progress.

    Args:
        model: network to optimise. It is switched to training mode.
        train_loader: sized iterable of ``(data, target)`` batches.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        num_epochs: number of passes over ``train_loader``
            (default 4, the value previously hard-coded).
        log_interval: print the batch loss every this many batches.
            0 or None disables per-batch logging.

    Returns:
        A list with one ``(mean_loss, accuracy_percent)`` tuple per epoch.
        Values are NaN for an epoch that saw no batches or samples.
    """
    model.train()
    num_batches = len(train_loader)
    history = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Bookkeeping only: use a detached copy so no graph is kept alive.
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            total_loss += loss.item()

            if log_interval and batch_idx % log_interval == 0:
                print(f'Epoch [{epoch+1}], Batch [{batch_idx}/{num_batches}], Loss: {loss.item():.4f}')

        # Guard against an empty loader instead of raising ZeroDivisionError.
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        print(f'Epoch [{epoch+1}], Loss: {mean_loss:.4f}, '
              f'Accuracy: {accuracy:.2f}%')
        history.append((mean_loss, accuracy))
    return history

In [19]:
def evaluate_model(model, test_loader, device):
    """Compute top-1 accuracy of ``model`` over ``test_loader``, print it, and return it.

    The model is switched to eval mode, and inference runs under ``torch.no_grad()``.

    Args:
        model: classifier returning ``(batch, num_classes)`` logits.
        test_loader: iterable of ``(data, target)`` batches.
        device: device every batch is moved to.

    Returns:
        Accuracy in percent, or NaN if the loader yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

In [20]:
def main(seed=0):
    """Train and evaluate ResNetSequentialMNIST on MNIST, feeding each digit as a one-frame clip.

    Args:
        seed: passed to ``torch.manual_seed`` so weight init and shuffling
            are reproducible across runs.
    """
    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    batch_size = 64
    learning_rate = 0.001

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        # The model's Conv3d stem expects (C, T, H, W) samples: insert a
        # length-1 time axis so each (1, 28, 28) image becomes (1, 1, 28, 28).
        transforms.Lambda(lambda img: img.unsqueeze(1)),
    ])

    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # The class defined in this notebook is ResNetSequentialMNIST; the previous
    # name `ResNetMNIST` is not defined anywhere and raised NameError.
    model = ResNetSequentialMNIST().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    try:
        train_model(model, train_loader, criterion, optimizer, device)
        evaluate_model(model, test_loader, device)

    except Exception as e:
        print(f"An error occurred during training or evaluation: {e}")
        import traceback
        traceback.print_exc()

In [None]:
# Script entry point: runs the full train + evaluate pipeline defined in main().
# NOTE: inside a Jupyter/IPython kernel __name__ is also '__main__', so running
# this cell starts training too.
if __name__ == '__main__':
    main()