# Assignment 1 - Code Example - Part A

This code baseline is inspired by and modified from [this great tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

This code can achieve an accuracy of approximately 86.50% on CIFAR-10. Please set up the environment and run your experiments starting from this baseline. You are expected to achieve an accuracy higher than this baseline.

## Data and preprocessing

In [2]:
# import some necessary packages
import torch
import torch.nn as nn
import torch.optim as optim
 
import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms

from time import time

In [3]:
# Experiment configuration shared by every cell that follows.
# Use the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_epochs = 128   # passes over the training loader
batch_size = 64    # samples per mini-batch
num_workers = 2    # value handed to DataLoader(num_workers=...)
print_every = 200  # iterations between running-loss log lines

# Optimizer class name and the keyword arguments it is built with.
optim_name = "Adam"
optim_kwargs = dict(lr=3e-4, weight_decay=1e-6)

# Per-split preprocessing: random augmentation for "train" only, followed by
# the tensor conversion and normalization that both splits share.
transformation = {}
for data_type in ("train", "test"):
    is_train = data_type == "train"
    augmentations = [
        tv_transforms.RandomRotation(degrees=15),
        tv_transforms.RandomHorizontalFlip(),
        tv_transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ] if is_train else []
    transformation[data_type] = tv_transforms.Compose([
        *augmentations,
        tv_transforms.ToTensor(),
        tv_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

In [4]:
# Build one CIFAR-10 dataset and one DataLoader per split.
# download=False: presumably the data must already exist under ./data — verify
# before running on a fresh machine.
dataset = {}
loader = {}
for data_type in ("train", "test"):
    is_train = data_type == "train"
    split = tv_datasets.CIFAR10(
        root="./data",
        train=is_train,
        download=False,
        transform=transformation[data_type],
    )
    dataset[data_type] = split
    # Only the training split is shuffled.
    loader[data_type] = torch.utils.data.DataLoader(
        split,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
    )


## model

### ConvBlock

In [5]:
class ConvBlock(nn.Module):
    """Conv2d followed by ReLU, with optional BatchNorm and pooling/dropout.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        Is_BN: when true, ``nn.BatchNorm2d(out_channels)`` runs after the ReLU.
        Is_reg: when true, ``nn.MaxPool2d(2)`` and ``nn.Dropout(0.3)`` run last.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, Is_BN=True, Is_reg=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )
        self.relu = nn.ReLU(inplace=True)
        self.Is_reg = Is_reg
        self.Is_BN = Is_BN
        # The optional sub-modules only exist when their flag is set.
        if Is_BN:
            self.bn = nn.BatchNorm2d(out_channels)
        if Is_reg:
            self.maxpool = nn.MaxPool2d(2)
            self.dp = nn.Dropout(0.3)

    def forward(self, x):
        """Apply conv -> ReLU -> [BatchNorm] -> [MaxPool -> Dropout]."""
        out = self.relu(self.conv(x))
        if self.Is_BN:
            out = self.bn(out)
        if not self.Is_reg:
            return out
        return self.dp(self.maxpool(out))

class ConvBs(nn.Module):
    """A stack of ``ConvBlock`` modules applied one after another.

    Args:
        num_layers: how many blocks to build.
        layer_dict: container indexed by ``0 .. num_layers - 1``; each entry
            is the keyword-argument dict for one ``ConvBlock``.
    """

    def __init__(self, num_layers, layer_dict):
        super().__init__()
        blocks = []
        for idx in range(num_layers):
            blocks.append(ConvBlock(**layer_dict[idx]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, train=True):
        """Run ``x`` through every block in order.

        ``train`` is accepted but never read in this method.
        """
        out = x
        for block in self.layers:
            out = block(out)
        return out

### initial net

In [6]:
# Baseline CNN: five 3x3 convolutions (three of them followed by 2x2
# max-pooling and dropout), then a four-layer fully connected classifier.
net = nn.Sequential(
    # conv stage 1: 3 -> 128
    nn.Conv2d(3, 128, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.3),
    # conv stage 2: 128 -> 256
    nn.Conv2d(128, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.3),
    # conv stage 3: 256 -> 512 -> 512 -> 256, pooled once at the end
    nn.Conv2d(256, 512, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.3),
    # classifier; 256 * 4 * 4 assumes 32x32 inputs halved by the three pools
    nn.Flatten(),
    nn.Linear(256 * 4 * 4, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(256, 128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(128, 10),
)

# place the parameters on the selected device
net.to(device)

# report the trainable parameter count in millions
print(f"number of parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad) / 1_000_000:.2f}M")

number of parameters: 7.28M


In [6]:
# Modular design
# Increase the number of filters and layers
# GoogLeNet: v1: NIN + global pooling
# v2: BN + replace each 5x5 conv with two 3x3 convs
# v3: factorization
# Residual connections

### ViT

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, normal_

class AddPositionEmbs(nn.Module):
    """Learned position embedding that is added to a token sequence.

    Args:
        seq_len: number of positions in the embedding table.
        emb_dim: width of each position vector.
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        # Normal samples scaled by 0.02; the leading 1 broadcasts over the batch.
        initial = torch.randn(1, seq_len, emb_dim) * 0.02
        self.pos_embedding = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` plus the position embedding."""
        return torch.add(x, self.pos_embedding)

class MlpBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_dim: input and output width.
        mlp_dim: hidden width between the two linear layers.
        dropout: rate of the dropout applied after each linear layer.
    """

    def __init__(self, in_dim, mlp_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

        # Parameter initialization: Xavier-uniform weights, near-zero biases.
        for linear in (self.fc1, self.fc2):
            xavier_uniform_(linear.weight)
            normal_(linear.bias, std=1e-6)

    def forward(self, x):
        """Apply the feed-forward block to ``x``."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class Encoder1DBlock(nn.Module):
    """Transformer encoder block: self-attention and an MLP, each with a residual add.

    LayerNorm is applied to the input of each sub-layer rather than to its
    output.

    Args:
        hidden_dim: token embedding width.
        mlp_dim: hidden width of the MLP sub-layer.
        num_heads: number of attention heads.
        dropout: rate for the dropout after attention and inside the MLP.
        attn_dropout: dropout rate passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout=0.1, attn_dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = MlpBlock(hidden_dim, mlp_dim, dropout)

        # Attention layer initialization: Xavier-uniform weights, near-zero biases.
        xavier_uniform_(self.attn.in_proj_weight)
        normal_(self.attn.in_proj_bias, std=1e-6)
        xavier_uniform_(self.attn.out_proj.weight)
        normal_(self.attn.out_proj.bias, std=1e-6)

    def forward(self, x):
        """Return the block output for ``x``; batch-first, as configured on ``self.attn``."""
        # Normalize once and hand the same tensor to query, key and value.
        # Computing norm1 three times tripled the LayerNorm work and produced
        # three distinct tensors for what is plain self-attention.
        normed = self.norm1(x)
        # The attention weights are discarded, so do not ask for them.
        attn_output, _ = self.attn(
            query=normed,
            key=normed,
            value=normed,
            need_weights=False,
        )
        x = x + self.dropout(attn_output)
        x = x + self.mlp(self.norm2(x))
        return x

class Encoder(nn.Module):
    """Transformer encoder: position embedding, dropout, block stack, final LayerNorm.

    Args:
        num_layers: number of ``Encoder1DBlock`` modules in the stack.
        hidden_dim: token embedding width.
        mlp_dim: hidden width passed to each block's MLP.
        num_heads: attention heads per block.
        dropout: rate for the embedding dropout and for each block.
        attn_dropout: dropout rate passed to each block's attention.
        seq_len: number of tokens covered by the learned position embedding.
            The default of 65 equals 64 patches plus one class token
            (a 32x32 image cut into 4x4 patches).
    """

    def __init__(self, num_layers, hidden_dim, mlp_dim, num_heads, dropout=0.1, attn_dropout=0.1,
                 seq_len=65):
        super().__init__()
        self.layers = nn.ModuleList([
            Encoder1DBlock(hidden_dim, mlp_dim, num_heads, dropout, attn_dropout)
            for _ in range(num_layers)
        ])
        # The embedding length has to equal the incoming token count, so it is
        # configurable instead of being fixed at 65.
        self.pos_emb = AddPositionEmbs(seq_len=seq_len, emb_dim=hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, train=True):
        """Encode a token sequence.

        Args:
            x: tokens whose sequence length matches ``seq_len``.
            train: when false, the embedding dropout is skipped.
        """
        x = self.pos_emb(x)
        x = self.dropout(x) if train else x
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping patches by a strided convolution, a
    class token is prepended, the sequence is encoded, and the class token is
    projected to logits.

    Args:
        num_classes: number of output logits.
        img_size: side length of the input images; inputs must have this
            size so the token count matches the position embedding.
        patch_size: side length of each patch.
        hidden_dim: token embedding width.
        num_layers: number of encoder blocks.
        num_heads: attention heads per block.
        mlp_dim: hidden width of each block's MLP.
        dropout: dropout rate used in the encoder.
        attn_dropout: dropout rate inside attention.
        representation_size: when truthy, a Linear + Tanh layer of this width
            is inserted before the classification head.
    """

    def __init__(self,
                 num_classes,
                 img_size=224,
                 patch_size=16,
                 hidden_dim=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_dim=3072,
                 dropout=0.1,
                 attn_dropout=0.1,
                 representation_size=None):

        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

        # Transformer encoder. One position per patch plus one for the class
        # token; previously the encoder always allocated 65 positions, which
        # only fits 64 patches and broke every other image/patch size.
        self.encoder = Encoder(
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            num_heads=num_heads,
            dropout=dropout,
            attn_dropout=attn_dropout,
            seq_len=num_patches + 1,
        )

        # Classification head
        self.pre_logits = nn.Identity()
        if representation_size:
            self.pre_logits = nn.Sequential(
                nn.Linear(hidden_dim, representation_size),
                nn.Tanh()
            )
            hidden_dim = representation_size

        self.head = nn.Linear(hidden_dim, num_classes)

        # Initialization
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self._init_weights()

    def _init_weights(self):
        """Initialize the patch embedding and zero the classification head."""
        # Patch-embedding convolution
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.normal_(self.patch_embed.bias, std=1e-6)

        # Classification head
        nn.init.zeros_(self.head.weight)
        nn.init.constant_(self.head.bias, 0)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Patch embedding [B, C, H, W] -> [B, hidden_dim, grid, grid]
        x = self.patch_embed(x)
        batch = x.shape[0]

        # Flatten the grid and move channels last -> [B, grid * grid, hidden_dim]
        x = x.flatten(2).transpose(1, 2)

        # Prepend the class token
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Transformer encoding; the flag mirrors this module's train/eval mode
        x = self.encoder(x, self.training)

        # Classify from the class token
        x = x[:, 0]
        x = self.pre_logits(x)
        return self.head(x)

## Training

In [8]:
def train(net,file_name):
    """Train ``net`` on ``loader["train"]`` and log the running loss.

    Relies on the module-level settings ``optim_name``, ``optim_kwargs``,
    ``num_epochs``, ``print_every``, ``loader`` and ``device``.

    Args:
        net: model to optimize; it is switched to training mode.
        file_name: path of the text log. It is truncated when the first log
            line is written and appended to afterwards.

    Returns:
        The list of log lines, one for every ``print_every`` iterations.
    """
    # the network optimizer, looked up by name on torch.optim
    optimizer = getattr(optim, optim_name)(net.parameters(), **optim_kwargs)

    # loss function
    criterion = nn.CrossEntropyLoss()

    # training loop
    net.train()

    outputs = []

    for epoch in range(num_epochs):
        # Both accumulators restart every epoch, so iterations after the last
        # full window of print_every in an epoch are never reported.
        epoch_t = 0
        running_loss = 0.0
        for i, (img, target) in enumerate(loader["train"]):
            s = time()
            img, target = img.to(device), target.to(device)

            pred = net(img)
            loss = criterion(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # NOTE(review): nothing synchronizes the device before this
            # timestamp; on CUDA the measured interval may exclude work that
            # is still running. Confirm if exact timings matter.
            e = time()
            epoch_t += e - s
            if i ==0:
                output = f"time: {e-s:.3f} seconds"
                print(output)

            # print statistics
            running_loss += loss.item()
            if i % print_every == print_every - 1:
                # "epoch time" is the step time accumulated since the previous
                # log line, because epoch_t is reset below.
                output = f"[epoch={epoch + 1:3d}, iter={i + 1:5d}] loss: {running_loss / print_every:.3f} epoch time: {epoch_t:.3f} seconds"
                print(output)
                outputs.append(output)

                # Truncate on the first record and append afterwards. The file
                # ends up with the same content as rewriting the whole history
                # each time, without the quadratic amount of I/O.
                mode = "w" if len(outputs) == 1 else "a"
                with open(file_name, mode, encoding="utf-8") as f:
                    f.write(output + "\n")
                running_loss = 0.0
                epoch_t = 0
    print("Finished Training")
    return outputs


## Evaluating its accuracy

In [9]:
def evaluate_accuracy(net):
    """Measure top-1 accuracy of ``net`` on ``loader["test"]``.

    The network is put into evaluation mode and left there. The summary line
    is printed and also returned.

    Args:
        net: model to evaluate.

    Returns:
        The formatted accuracy summary string.
    """
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for img, target in loader["test"]:
            img = img.to(device)
            target = target.to(device)

            # predicted class = index of the largest logit in each row
            predicted = net(img).argmax(dim=1)

            total += len(target)
            correct += predicted.eq(target).sum().item()

    output = f"Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%"
    print(output)
    return output

## run

In [10]:
# Small ViT: 32x32 inputs cut into 4x4 patches, 6 blocks of width 192.
net = VisionTransformer(
    num_classes=10,
    img_size=32,
    patch_size=4,
    hidden_dim=192,
    num_layers=6,
    num_heads=6,
    mlp_dim=768,
    dropout=0.1,
    attn_dropout=0.1,
    representation_size=None,
)

# place the parameters on the selected device
net.to(device)

# report the trainable parameter count in millions
print(f"number of parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad) / 1_000_000:.2f}M")

# Train, evaluate, then rewrite the log so it ends with the accuracy line.
file_name = "Vit_output.txt"
ops = train(net, file_name)
ops.append(evaluate_accuracy(net))
with open(file_name, "w") as f:
    f.writelines(out + "\n" for out in ops)

# TODO: 1. AdamW  2. warmup  3. cosine  4. big batch size



number of parameters: 2.69M
time: 0.561 seconds
[epoch=  1, iter=  200] loss: 2.009 epoch time: 3.488 seconds
[epoch=  1, iter=  400] loss: 1.822 epoch time: 3.248 seconds
[epoch=  1, iter=  600] loss: 1.739 epoch time: 3.616 seconds
time: 0.023 seconds
[epoch=  2, iter=  200] loss: 1.642 epoch time: 3.472 seconds
[epoch=  2, iter=  400] loss: 1.576 epoch time: 3.538 seconds
[epoch=  2, iter=  600] loss: 1.521 epoch time: 3.918 seconds
time: 0.022 seconds
[epoch=  3, iter=  200] loss: 1.470 epoch time: 3.121 seconds
[epoch=  3, iter=  400] loss: 1.448 epoch time: 3.382 seconds
[epoch=  3, iter=  600] loss: 1.408 epoch time: 3.691 seconds
time: 0.029 seconds
[epoch=  4, iter=  200] loss: 1.374 epoch time: 3.230 seconds
[epoch=  4, iter=  400] loss: 1.373 epoch time: 3.044 seconds
[epoch=  4, iter=  600] loss: 1.360 epoch time: 3.083 seconds
time: 0.021 seconds
[epoch=  5, iter=  200] loss: 1.320 epoch time: 3.164 seconds
[epoch=  5, iter=  400] loss: 1.318 epoch time: 3.287 seconds
[epo

In [None]:
# Seven 3x3 ConvBlocks with BatchNorm. Only the last three (index >= 4) also
# apply max-pooling and dropout.
layer_dict = {
    idx: {
        'in_channels': c_in, 'out_channels': c_out, 'kernel_size': 3,
        'Is_reg': idx >= 4, 'Is_BN': True, 'stride': 1, 'padding': 1,
    }
    for idx, (c_in, c_out) in enumerate([
        (3, 128), (128, 256), (256, 512), (512, 1024),
        (1024, 512), (512, 256), (256, 128),
    ])
}
net = nn.Sequential(
    ConvBs(7, layer_dict),
    nn.Flatten(),
    # 128 * 4 * 4 assumes 32x32 inputs halved by the three pooled blocks
    nn.Linear(128 * 4 * 4, 256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(256, 128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(128, 10),
)

# place the parameters on the selected device
net.to(device)

# report the trainable parameter count in millions
print(f"number of parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad) / 1_000_000:.2f}M")

# Train, evaluate, then rewrite the log so it ends with the accuracy line.
file_name = "BN_output.txt"
ops = train(net, file_name)
ops.append(evaluate_accuracy(net))
with open(file_name, "w") as f:
    f.writelines(out + "\n" for out in ops)

number of parameters: 12.96M
time: 0.091 seconds
[epoch=  1, iter=  200] loss: 1.952 epoch time: 1.227 seconds
[epoch=  1, iter=  400] loss: 1.670 epoch time: 1.116 seconds
[epoch=  1, iter=  600] loss: 1.515 epoch time: 1.116 seconds
time: 0.009 seconds


In [10]:
# Baseline CNN: five 3x3 convolutions (three of them followed by 2x2
# max-pooling and dropout), then a four-layer fully connected classifier.
net = nn.Sequential(
    # conv stage 1: 3 -> 128
    nn.Conv2d(3, 128, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.3),
    # conv stage 2: 128 -> 256
    nn.Conv2d(128, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.3),
    # conv stage 3: 256 -> 512 -> 512 -> 256, pooled once at the end
    nn.Conv2d(256, 512, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.3),
    # classifier; 256 * 4 * 4 assumes 32x32 inputs halved by the three pools
    nn.Flatten(),
    nn.Linear(256 * 4 * 4, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(256, 128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(128, 10),
)

# place the parameters on the selected device
net.to(device)

# report the trainable parameter count in millions
print(f"number of parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad) / 1_000_000:.2f}M")

# Train, evaluate, then rewrite the log so it ends with the accuracy line.
file_name = "init_output.txt"
ops = train(net, file_name)
ops.append(evaluate_accuracy(net))
with open(file_name, "w") as f:
    f.writelines(out + "\n" for out in ops)

number of parameters: 7.28M
time: 0.389 seconds
[epoch=  1, iter=  200] loss: 2.198 epoch time: 1.002 seconds
[epoch=  1, iter=  400] loss: 1.972 epoch time: 0.591 seconds
[epoch=  1, iter=  600] loss: 1.856 epoch time: 0.673 seconds
time: 0.008 seconds
[epoch=  2, iter=  200] loss: 1.660 epoch time: 0.626 seconds
[epoch=  2, iter=  400] loss: 1.603 epoch time: 0.613 seconds
[epoch=  2, iter=  600] loss: 1.537 epoch time: 0.629 seconds
time: 0.007 seconds
[epoch=  3, iter=  200] loss: 1.467 epoch time: 0.642 seconds
[epoch=  3, iter=  400] loss: 1.405 epoch time: 0.662 seconds
[epoch=  3, iter=  600] loss: 1.385 epoch time: 0.675 seconds
time: 0.007 seconds
[epoch=  4, iter=  200] loss: 1.290 epoch time: 0.678 seconds
[epoch=  4, iter=  400] loss: 1.267 epoch time: 0.665 seconds
[epoch=  4, iter=  600] loss: 1.255 epoch time: 0.693 seconds
time: 0.008 seconds
[epoch=  5, iter=  200] loss: 1.185 epoch time: 0.687 seconds
[epoch=  5, iter=  400] loss: 1.187 epoch time: 0.731 seconds
[epo