In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
class FSAM:
    def __init__(self, rho = 0.05, sigma = 1, momentum = 0.6, lr = 0.05):
        self.rho = rho
        self.sigma = sigma
        self.momentum = momentum
        self.mt = None
        self.lr = lr
        
    def train_model(self, model, inputs, labels, loss_fn):
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        
        g_t = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        
        # Tinh m_t
        if self.mt is None:
            self.mt = [(1 - self.momentum) * g for g in g_t]
        else:
            self.mt = [self.momentum * m + (1 - self.momentum) * g for m, g in zip(self.mt, g_t)]
        
        # Tinh adversarial pertubation
        d_t = [g - self.sigma * m for g, m in zip(g_t, self.mt)]
        d_t_norm = torch.norm(torch.stack([torch.norm(d) for d in d_t]))
        epsilon_t = [self.rho * d / (d_t_norm + 1e-12) for d in d_t]
        
        # Tinh gradient approximation
        with torch.no_grad():
            for p, epsilon in zip(model.parameters(), epsilon_t):
                if p.grad is not None:
                    p.add_(epsilon)
        model.zero_grad()
        
        pertubation_output = model(inputs)
        pertubation_loss = loss_fn(pertubation_output, labels)
        pertubation_loss.backward()
        app_gt = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        
        # Khoi phuc tham so
        with torch.no_grad():
            for p, epsilon in zip(model.parameters(), epsilon_t):
                if p.grad is not None:
                    p.sub_(epsilon)
        
        # Cap nhat tham so
        with torch.no_grad():
            for p, app_g in zip(model.parameters(), app_gt):
                if p.grad is not None:
                    p.sub_(self.lr * app_g)
        
        model.zero_grad()

In [3]:
# CIFAR-10 dataset

train_data = torchvision.datasets.CIFAR10(
    root = './data',
    train = True,
    download = True,
    transform = transforms.ToTensor()
)
train_dataloader = torch.utils.data.DataLoader(
    train_data, 
    batch_size = 128,
    shuffle = True
)

test_data = torchvision.datasets.CIFAR10(
    root = './data',
    train = False,
    download = True,
    transform = transforms.ToTensor()
)
test_dataloader = torch.utils.data.DataLoader(
    test_data,
    batch_size = 128,
    shuffle = False
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48517352.52it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
# Khoi tao model
from torchvision.models import resnet18, ResNet18_Weights

model = torchvision.models.resnet18(weights = ResNet18_Weights.DEFAULT)
loss_fn = nn.CrossEntropyLoss()
optimizer = FSAM()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 158MB/s]


In [5]:
for epoch in range(10):
    model.train()
    running_loss = 0 
    for i, (inputs, labels) in enumerate(train_dataloader, 0):
        
        optimizer.train_model(model, inputs, labels, loss_fn)
        
        running_loss += loss_fn(model(inputs), labels).item()
        if i % 100 == 99:                # 100 minibatch
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

# Save model
torch.save(model.state_dict(), 'cifar10_fsam.pth')
    

[Epoch 1, Batch 100] loss: 1.099
[Epoch 1, Batch 200] loss: 0.557
[Epoch 1, Batch 300] loss: 0.467
[Epoch 2, Batch 100] loss: 0.335
[Epoch 2, Batch 200] loss: 0.318
[Epoch 2, Batch 300] loss: 0.303
[Epoch 3, Batch 100] loss: 0.224
[Epoch 3, Batch 200] loss: 0.207
[Epoch 3, Batch 300] loss: 0.212
[Epoch 4, Batch 100] loss: 0.144
[Epoch 4, Batch 200] loss: 0.143
[Epoch 4, Batch 300] loss: 0.148
[Epoch 5, Batch 100] loss: 0.104
[Epoch 5, Batch 200] loss: 0.093
[Epoch 5, Batch 300] loss: 0.101
[Epoch 6, Batch 100] loss: 0.069
[Epoch 6, Batch 200] loss: 0.063
[Epoch 6, Batch 300] loss: 0.068
[Epoch 7, Batch 100] loss: 0.047
[Epoch 7, Batch 200] loss: 0.043
[Epoch 7, Batch 300] loss: 0.048
[Epoch 8, Batch 100] loss: 0.033
[Epoch 8, Batch 200] loss: 0.030
[Epoch 8, Batch 300] loss: 0.032
[Epoch 9, Batch 100] loss: 0.023
[Epoch 9, Batch 200] loss: 0.024
[Epoch 9, Batch 300] loss: 0.022
[Epoch 10, Batch 100] loss: 0.016
[Epoch 10, Batch 200] loss: 0.018
[Epoch 10, Batch 300] loss: 0.016
Finishe

In [6]:
# Test model

model.eval()
correct, total = 0.0, 0.0
with torch.no_grad():
    for inputs, labels in test_dataloader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

Accuracy of the network on the 10000 test images: 83.36%
