In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
torch.cuda.empty_cache()

In [4]:
# preprocess
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * self.expansion)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [6]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, out.size()[2:])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def Net():
    return ResNet(Bottleneck, [3, 4, 6, 3])

In [7]:
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

In [8]:
def train_model(num_epochs=100):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if (i + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], Loss: {running_loss/100:.4f}')
                running_loss = 0.0
        scheduler.step()


In [9]:
def test_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy on test set: {100 * correct / total:.2f}%')
    

In [10]:
train_model(num_epochs=100)

Epoch [1/100], Step [100], Loss: 2.0374
Epoch [1/100], Step [200], Loss: 1.5895
Epoch [1/100], Step [300], Loss: 1.4052
Epoch [2/100], Step [100], Loss: 1.1565
Epoch [2/100], Step [200], Loss: 1.0829
Epoch [2/100], Step [300], Loss: 1.0229
Epoch [3/100], Step [100], Loss: 0.9220
Epoch [3/100], Step [200], Loss: 0.8880
Epoch [3/100], Step [300], Loss: 0.8483
Epoch [4/100], Step [100], Loss: 0.7913
Epoch [4/100], Step [200], Loss: 0.7424
Epoch [4/100], Step [300], Loss: 0.7253
Epoch [5/100], Step [100], Loss: 0.6717
Epoch [5/100], Step [200], Loss: 0.6538
Epoch [5/100], Step [300], Loss: 0.6503
Epoch [6/100], Step [100], Loss: 0.5981
Epoch [6/100], Step [200], Loss: 0.6073
Epoch [6/100], Step [300], Loss: 0.5969
Epoch [7/100], Step [100], Loss: 0.5438
Epoch [7/100], Step [200], Loss: 0.5632
Epoch [7/100], Step [300], Loss: 0.5484
Epoch [8/100], Step [100], Loss: 0.4993
Epoch [8/100], Step [200], Loss: 0.5247
Epoch [8/100], Step [300], Loss: 0.5166
Epoch [9/100], Step [100], Loss: 0.4846


In [11]:
test_model()

Accuracy on test set: 93.97%


In [12]:
# save model
path = './models/best.pth'
torch.save(model.state_dict(), path)

In [13]:
# visualize the first convolutional layer's filters
import matplotlib.pyplot as plt
import numpy as np

def visualize_filters(model, save_path='filters_conv1.png'):
    # (out_channels, in_channels, kernel_size, kernel_size), [64, 3, 3, 3]
    weights = model.conv1.weight.data.cpu().numpy()
    
    weights = (weights - weights.min()) / (weights.max() - weights.min())

    fig, axes = plt.subplots(8, 8, figsize=(12, 12))
    for i in range(64):
        row = i // 8
        col = i % 8
        filter_img = weights[i].transpose(1, 2, 0)
        axes[row, col].imshow(filter_img)
        axes[row, col].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f'Filter visualization saved to {save_path}')

In [14]:
# visualize loss landscape
from mpl_toolkits.mplot3d import Axes3D

def compute_loss(model, dataloader, criterion):
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            total_samples += images.size(0)
    return total_loss / total_samples

def visualize_loss_landscape(model, dataloader, criterion, save_path='loss_landscape.png'):
    model.eval()
    original_params = [p.clone().detach() for p in model.parameters()]
    
    # generate two random directions
    def normalize_direction(direction):
        norm = sum(torch.sum(d**2) for d in direction)**0.5
        return [d / norm for d in direction]
    
    direction1 = [torch.randn_like(p) for p in model.parameters()]
    direction2 = [torch.randn_like(p) for p in model.parameters()]
    direction1 = normalize_direction(direction1)
    direction2 = normalize_direction(direction2)
    
    grid_size = 20
    alpha = np.linspace(-1, 1, grid_size)
    beta = np.linspace(-1, 1, grid_size)
    losses = np.zeros((grid_size, grid_size))

    for i, a in enumerate(alpha):
        for j, b in enumerate(beta):
            for p, d1, d2, orig in zip(model.parameters(), direction1, direction2, original_params):
                p.data = orig + a * d1 + b * d2
            losses[i, j] = compute_loss(model, dataloader, criterion)

    for p, orig in zip(model.parameters(), original_params):
        p.data = orig

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    X, Y = np.meshgrid(alpha, beta)
    surf = ax.plot_surface(X, Y, losses, cmap='viridis')
    ax.set_xlabel('Direction 1')
    ax.set_ylabel('Direction 2')
    ax.set_zlabel('Loss')
    ax.set_title('Loss Landscape')
    fig.colorbar(surf)
    plt.savefig(save_path)
    plt.close()
    print(f'Loss landscape saved to {save_path}')

In [15]:
# visualize interpretation using Grad-CAM
def visualize_grad_cam(model, dataloader, num_images=10, save_path='grad_cam.png'):
    model.eval()
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
    # register hook for layer4
    features = []
    gradients = []
    def save_features(module, input, output):
        features.append(output.detach())
    def save_gradients(module, grad_in, grad_out):
        gradients.append(grad_out[0].detach())
    
    model.layer4.register_forward_hook(save_features)
    model.layer4.register_backward_hook(save_gradients)
    
    images, labels = next(iter(dataloader))
    images, labels = images[:num_images].to(device), labels[:num_images].to(device)
    
    # original image
    inv_normalize = transforms.Normalize(
        mean=[-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010],
        std=[1/0.2023, 1/0.1994, 1/0.2010]
    )
    
    fig, axes = plt.subplots(2, num_images, figsize=(num_images*2, 4))
    
    for i in range(num_images):
        img = images[i:i+1]
        label = labels[i].item()

        features.clear()
        gradients.clear()
        output = model(img)
        pred = output.argmax(dim=1).item()

        model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, label] = 1
        output.backward(gradient=one_hot)
  
        feature = features[0]  # [1, 2048, h, w]
        grad = gradients[0]    # [1, 2048, h, w]
        weights = torch.mean(grad, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * feature, dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam - cam.min()
        cam = cam / cam.max()

        cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
        cam = cam[0, 0].cpu().numpy()

        img_display = inv_normalize(img[0]).permute(1, 2, 0).cpu().numpy()
        img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min())
        axes[0, i].imshow(img_display)
        axes[0, i].set_title(f'Label: {classes[label]}\nPred: {classes[pred]}')
        axes[0, i].axis('off')

        axes[1, i].imshow(img_display)
        axes[1, i].imshow(cam, cmap='jet', alpha=0.5)
        axes[1, i].set_title('Grad-CAM')
        axes[1, i].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f'Grad-CAM visualization saved to {save_path}')

In [16]:
visualize_filters(model)
visualize_loss_landscape(model, testloader, criterion)
visualize_grad_cam(model, testloader)

Filter visualization saved to filters_conv1.png
Loss landscape saved to loss_landscape.png


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Grad-CAM visualization saved to grad_cam.png
