In [1]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.12.9-py2.py3-none-any.whl (1.7 MB)
[?25l[K     |▏                               | 10 kB 23.7 MB/s eta 0:00:01[K     |▍                               | 20 kB 27.8 MB/s eta 0:00:01[K     |▋                               | 30 kB 16.9 MB/s eta 0:00:01[K     |▊                               | 40 kB 12.0 MB/s eta 0:00:01[K     |█                               | 51 kB 5.8 MB/s eta 0:00:01[K     |█▏                              | 61 kB 6.2 MB/s eta 0:00:01[K     |█▍                              | 71 kB 5.4 MB/s eta 0:00:01[K     |█▌                              | 81 kB 6.1 MB/s eta 0:00:01[K     |█▊                              | 92 kB 6.4 MB/s eta 0:00:01[K     |██                              | 102 kB 5.3 MB/s eta 0:00:01[K     |██                              | 112 kB 5.3 MB/s eta 0:00:01[K     |██▎                             | 122 kB 5.3 MB/s eta 0:00:01[K     |██▌                             | 133 kB 5.3 MB/s eta 0:00:01

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import wandb

In [None]:
wandb.init(project="Persian-MNIST-TransferLearning-by-Torch")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet50(pretrained=True)

## Transfer Learning

In [5]:
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)

This freezes layers 1-6 in the total 10 layers of Resnet50

In [8]:
ct = 0
for child in model.children():
    ct += 1
    if ct < 7:
        for param in child.parameters():
            param.requires_grad = False

In [9]:
model = model.to(device)

In [10]:
config = wandb.config
config.learning_rate = 0.001
config.batch_size = 64
config.epochs = 20

Dataset

In [12]:
transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.Resize((70, 70)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

dataset = torchvision.datasets.ImageFolder(root='/content/drive/MyDrive/Datasets/MNIST_persian', transform=transform)
train_data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
loss_function = torch.nn.CrossEntropyLoss()

In [14]:
def calc_acc(preds, labels):
    preds_max = torch.argmax(preds, 1)
    acc = torch.sum(preds_max == labels.data, dtype=torch.float64) / len(preds)
    return acc

In [None]:
model.train()

## Train

In [16]:
wandb.watch(model)

for epoch in range(config.epochs):
    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_data_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # 1- forwarding
        preds = model(images)
        # 2- backwarding 
        loss = loss_function(preds, labels)
        loss.backward()
        # 3- Update
        optimizer.step()

        train_loss += loss
        train_acc += calc_acc(preds, labels)
    
    total_loss = train_loss / len(train_data_loader)
    total_acc = train_acc / len(train_data_loader)

    if epoch % 2 == 0:
        wandb.log({"loss": total_loss})
        wandb.log({"acc": total_acc})

    print(f"Epoch: {epoch}, Loss: {total_loss}, Acc: {total_acc}")

Epoch: 0, Loss: 0.8118663430213928, Acc: 0.7376644736842105
Epoch: 1, Loss: 0.20373013615608215, Acc: 0.9465460526315789
Epoch: 2, Loss: 0.10911434888839722, Acc: 0.9703947368421052
Epoch: 3, Loss: 0.0627252459526062, Acc: 0.9805372807017543
Epoch: 4, Loss: 0.10923654586076736, Acc: 0.972861842105263
Epoch: 5, Loss: 0.09046061336994171, Acc: 0.9810855263157894
Epoch: 6, Loss: 0.08001021295785904, Acc: 0.9766995614035088
Epoch: 7, Loss: 0.058515530079603195, Acc: 0.9860197368421052
Epoch: 8, Loss: 0.04184062033891678, Acc: 0.9868421052631579
Epoch: 9, Loss: 0.03258498013019562, Acc: 0.9931469298245614
Epoch: 10, Loss: 0.08805110305547714, Acc: 0.9794407894736842
Epoch: 11, Loss: 0.054632022976875305, Acc: 0.9857456140350878
Epoch: 12, Loss: 0.04272497445344925, Acc: 0.9893092105263157
Epoch: 13, Loss: 0.03697117418050766, Acc: 0.9901315789473684
Epoch: 14, Loss: 0.054149381816387177, Acc: 0.9849232456140351
Epoch: 15, Loss: 0.025516046211123466, Acc: 0.9923245614035088
Epoch: 16, Loss: 