In [1]:
import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import swanlab

# CNN model definition
class ConvNet(nn.Module):
    """Small two-convolution CNN for single-channel 28x28 MNIST digits.

    Shape trace for an input of shape (N, 1, 28, 28):
        conv1 (5x5) + ReLU   -> (N, 10, 24, 24)
        max-pool 2x2         -> (N, 10, 12, 12)
        conv2 (3x3) + ReLU   -> (N, 20, 10, 10)
        flatten              -> (N, 2000)
        fc1 + ReLU           -> (N, 500)
        fc2                  -> (N, 10)

    ``forward`` returns per-class log-probabilities (``log_softmax`` over
    dim 1), so the natural loss to pair it with is ``nn.NLLLoss``
    (``nn.CrossEntropyLoss`` would apply ``log_softmax`` again, redundantly).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so parameter initialisation
        # draws from the RNG identically and the state_dict keys
        # (conv1/conv2/fc1/fc2) stay stable for saved checkpoints.
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)  # 20 channels x 10 x 10 after conv2
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        batch = x.size(0)
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        features = F.relu(self.conv2(features))
        hidden = F.relu(self.fc1(features.view(batch, -1)))
        return F.log_softmax(self.fc2(hidden), dim=1)


# Capture the first `num_images` training images (16 by default) and log them to SwanLab as a preview
def log_images(loader, num_images=16):
    """Log a preview of the first ``num_images`` samples from ``loader`` to SwanLab.

    Batches of ``(images, labels)`` are pulled from ``loader`` only until
    ``num_images`` samples have been collected. Each sample is wrapped in a
    ``swanlab.Image`` captioned ``"Label: <label>"``. The whole list is then
    logged in a single ``swanlab.log`` call under the key ``"MNIST-Preview"``.

    Args:
        loader: iterable yielding ``(images, labels)`` batches indexable along dim 0.
        num_images: maximum number of samples to include in the preview.
    """
    preview = []
    for batch_images, batch_labels in loader:
        for idx in range(batch_images.shape[0]):
            if len(preview) >= num_images:
                break
            preview.append(swanlab.Image(batch_images[idx], caption=f"Label: {batch_labels[idx]}"))
        # Stop pulling further batches once enough samples are collected.
        if len(preview) >= num_images:
            break
    swanlab.log({"MNIST-Preview": preview})
    

def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs, log_interval=20):
    """Run one training epoch of ``model`` over ``train_dataloader``.

    For every batch: move it to ``device``, run a forward pass, compute
    ``criterion(outputs, labels)``, backpropagate, and take one optimizer step.
    A progress line is printed every iteration. The batch loss is logged to
    SwanLab under ``"train/loss"`` on 0-based iterations
    0, log_interval, 2*log_interval, ...

    Args:
        model: the module being trained; switched to train mode here.
        device: target passed to ``Tensor.to`` for each batch.
        train_dataloader: iterable of ``(inputs, labels)`` batches that also
            supports ``len()`` (used in the progress message).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        epoch: current epoch number as passed by the caller (only printed).
        num_epochs: total number of epochs (only printed).
        log_interval: log the loss to SwanLab every this many iterations;
            defaults to 20, the previously hard-coded value.

    Raises:
        ValueError: if ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    model.train()
    num_batches = len(train_dataloader)
    for step, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() forces a host/device sync; do it once and reuse the value.
        loss_value = loss.item()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(
            epoch, num_epochs, step + 1, num_batches, loss_value))
        if step % log_interval == 0:
            swanlab.log({"train/loss": loss_value})

def test(model, device, val_dataloader, epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # 1. Âæ™ÁéØË∞ÉÁî®val_dataloaderÔºåÊØèÊ¨°ÂèñÂá∫1‰∏™batch_sizeÁöÑÂõæÂÉèÂíåÊ†áÁ≠æ
        for inputs, labels in val_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # 2. ‰º†ÂÖ•Âà∞resnet18Ê®°Âûã‰∏≠ÂæóÂà∞È¢ÑÊµãÁªìÊûú
            outputs = model(inputs)
            # 3. Ëé∑ÂæóÈ¢ÑÊµãÁöÑÊï∞Â≠ó
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # 4. ËÆ°ÁÆó‰∏éÊ†áÁ≠æ‰∏ÄËá¥ÁöÑÈ¢ÑÊµãÁªìÊûúÁöÑÊï∞Èáè
            correct += (predicted == labels).sum().item()
    
        # 5. ÂæóÂà∞ÊúÄÁªàÁöÑÊµãËØïÂáÜÁ°ÆÁéá
        accuracy = correct / total
        # 6. Áî®SwanLabËÆ∞ÂΩï‰∏Ä‰∏ãÂáÜÁ°ÆÁéáÁöÑÂèòÂåñ
        swanlab.log({"val/accuracy": accuracy}, step=epoch)
    

if __name__ == "__main__":

    # Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
    # Older PyTorch builds have no torch.backends.mps, hence the AttributeError guard.
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    # Initialise the SwanLab run; the hyperparameters below are read back from run.config.
    run = swanlab.init(
        project="MNIST-example",
        experiment_name="PlainCNN",
        config={
            # Was "ResNet18", which mislabelled this run: the model trained is ConvNet.
            "model": "ConvNet",
            "optim": "Adam",
            "lr": 1e-4,
            "batch_size": 256,
            "num_epochs": 10,
            "device": device,
            "seed": 42,
        },
    )

    # Seed the global RNG so the train/val split, weight init and shuffling are reproducible.
    torch.manual_seed(run.config.seed)

    # Split MNIST's 60,000 training images into 55,000 train / 5,000 validation.
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])

    train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
    # ConvNet has no batch-dependent layers (no BatchNorm/Dropout), so the validation
    # batch size does not affect accuracy; a larger batch simply evaluates faster.
    val_dataloader = utils.data.DataLoader(val_dataset, batch_size=run.config.batch_size, shuffle=False)

    # (Optional) log the first 16 training images to SwanLab as a preview.
    log_images(train_dataloader, 16)

    # Build the model, move it to the chosen device and print its architecture.
    model = ConvNet()
    model.to(torch.device(device))
    print(model)

    # ConvNet.forward already returns log-probabilities, so NLLLoss is the matching
    # criterion; CrossEntropyLoss would apply log_softmax a second time (redundant work).
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=run.config.lr)

    num_epochs = run.config.num_epochs
    for epoch in range(1, num_epochs + 1):
        swanlab.log({"train/epoch": epoch}, step=epoch)
        train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs)
        # Validate every second epoch, and always after the final epoch so the
        # finished model is evaluated even when num_epochs is odd.
        if epoch % 2 == 0 or epoch == num_epochs:
            test(model, device, val_dataloader, epoch)

    # Save the final weights; exist_ok avoids a separate exists-check before creating the folder.
    os.makedirs("checkpoint", exist_ok=True)
    torch.save(model.state_dict(), os.path.join("checkpoint", "latest_checkpoint.pth"))

ConvNet(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=2000, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
Epoch [1/10], Iteration [1/215], Loss: 2.3026
Epoch [1/10], Iteration [2/215], Loss: 2.2965
Epoch [1/10], Iteration [3/215], Loss: 2.2896
Epoch [1/10], Iteration [4/215], Loss: 2.2756
Epoch [1/10], Iteration [5/215], Loss: 2.2690
Epoch [1/10], Iteration [6/215], Loss: 2.2590
Epoch [1/10], Iteration [7/215], Loss: 2.2569
Epoch [1/10], Iteration [8/215], Loss: 2.2524
Epoch [1/10], Iteration [9/215], Loss: 2.2425
Epoch [1/10], Iteration [10/215], Loss: 2.2361
Epoch [1/10], Iteration [11/215], Loss: 2.2196
Epoch [1/10], Iteration [12/215], Loss: 2.2080
Epoch [1/10], Iteration [13/215], Loss: 2.2021
Epoch [1/10], Iteration [14/215], Loss: 2.1882
Epoch [1/10], Iteration [15/215], Loss: 2.1809
Epoch [1/10], Iteration [16/215], Loss: 2.1639
E

In [2]:
# Close the SwanLab run opened by swanlab.init(...) in the previous cell; run this once training has completed.
swanlab.finish()