In [None]:
!pip install wandb

In [110]:
import torch, torchvision, wandb
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm

🔸 wandb config

In [111]:
configs = {
            "learning_rate": 0.001,
            "epochs": 20,
            "batch_size": 64,
           }

wandb.init(project="Persian-MNIST-by-Torch", config=configs)
config = wandb.config

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epochs,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train_acc,▁▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▁▃▅▅▆▇▇▇▇▇▇█████████
train_loss,█▆▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂█▆▄▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
epochs,19.0
train_acc,0.96436
train_loss,1.49765


In [112]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3), (1, 1), (1, 1))
        self.conv2 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(32, 64, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1))

        self.fc1 = nn.Linear(128*8*8, 512)
        self.fc2 = nn.Linear(512, 10)


    def forward(self, x):
      x = F.relu(self.conv1(x))
      x = F.max_pool2d(x, kernel_size=(2, 2))
      x = F.relu(self.conv2(x))
      x = F.max_pool2d(x, kernel_size=(2, 2))
      x = F.relu(self.conv3(x))
      x = F.max_pool2d(x, kernel_size=(2, 2))
      x = F.relu(self.conv4(x))
      x = torch.flatten(x, start_dim=1)
      x = F.relu(self.fc1(x))
      x = torch.flatten(x, start_dim=1)
      x = torch.dropout(x, 0.2, train=True)
      x = self.fc2(x)
      x = torch.softmax(x, dim=1)

      return x

In [113]:
batch = 32
epoch = 20
lr = 0.001
config.learning_rate = lr

In [114]:
device = torch.device("cuda")
model = Model().to(device)

In [115]:
data_transform = transforms.Compose([
                                     transforms.RandomRotation(10),
                                     transforms.Resize((70, 70)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                     ])

dataset = torchvision.datasets.ImageFolder(root='/content/drive/MyDrive/Datasets/MNIST_persian', transform = data_transform)
dataset = torch.utils.data.DataLoader(dataset, batch_size = batch, shuffle = True)

In [116]:
optimizer = torch.optim.Adam(model.parameters(), lr = lr)
loss_func = nn.CrossEntropyLoss()

In [117]:
def calc_acc(y_hat,labels):
    _, y_hat_max = torch.max(y_hat,1)
    acc=torch.sum(y_hat_max == labels.data,dtype = torch.float64) / len(y_hat)
    return acc


In [118]:
model.train()

Model(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=8192, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)

In [119]:
for ep in range(epoch):
    train_loss = 0.0
    train_acc = 0.0

    for im, labels in tqdm(dataset):
        im = im.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        #forwarding
        y_hat = model(im)

        #backwarding
        loss = loss_func(y_hat,labels)
        loss.backward()

        #update
        optimizer.step()

        train_loss += loss
        train_acc += calc_acc(y_hat,labels)

    total_loss = train_loss/len(dataset)
    total_acc = train_acc/len(dataset)

    print(f"epoch:{ep} , Loss:{total_loss} , accuracy: {total_acc}")

    wandb.log({'epochs': ep,
               'train_acc': total_acc,
               'train_loss': total_loss})

100%|██████████| 38/38 [00:03<00:00, 10.53it/s]


epoch:0 , Loss:2.294548749923706 , accuracy: 0.14720394736842105


100%|██████████| 38/38 [00:03<00:00, 11.02it/s]


epoch:1 , Loss:1.9585835933685303 , accuracy: 0.5139802631578947


100%|██████████| 38/38 [00:03<00:00, 10.81it/s]


epoch:2 , Loss:1.8185824155807495 , accuracy: 0.6447368421052632


100%|██████████| 38/38 [00:03<00:00, 10.85it/s]


epoch:3 , Loss:1.7782381772994995 , accuracy: 0.6800986842105263


100%|██████████| 38/38 [00:03<00:00, 10.76it/s]


epoch:4 , Loss:1.7698974609375 , accuracy: 0.6907894736842105


100%|██████████| 38/38 [00:03<00:00, 10.98it/s]


epoch:5 , Loss:1.7398816347122192 , accuracy: 0.7245065789473684


100%|██████████| 38/38 [00:03<00:00, 10.95it/s]


epoch:6 , Loss:1.6956721544265747 , accuracy: 0.7648026315789473


100%|██████████| 38/38 [00:03<00:00, 10.98it/s]


epoch:7 , Loss:1.6555861234664917 , accuracy: 0.8083881578947368


100%|██████████| 38/38 [00:03<00:00, 10.79it/s]


epoch:8 , Loss:1.5970768928527832 , accuracy: 0.8692434210526315


100%|██████████| 38/38 [00:03<00:00, 11.02it/s]


epoch:9 , Loss:1.6021983623504639 , accuracy: 0.8585526315789473


100%|██████████| 38/38 [00:03<00:00, 11.04it/s]


epoch:10 , Loss:1.5826966762542725 , accuracy: 0.8774671052631579


100%|██████████| 38/38 [00:03<00:00, 10.98it/s]


epoch:11 , Loss:1.5701498985290527 , accuracy: 0.8873355263157894


100%|██████████| 38/38 [00:03<00:00, 10.86it/s]


epoch:12 , Loss:1.5432857275009155 , accuracy: 0.919407894736842


100%|██████████| 38/38 [00:03<00:00, 10.97it/s]


epoch:13 , Loss:1.528116226196289 , accuracy: 0.9342105263157894


100%|██████████| 38/38 [00:03<00:00, 10.92it/s]


epoch:14 , Loss:1.5245059728622437 , accuracy: 0.9391447368421052


100%|██████████| 38/38 [00:03<00:00, 10.66it/s]


epoch:15 , Loss:1.5250413417816162 , accuracy: 0.9383223684210525


100%|██████████| 38/38 [00:03<00:00, 11.06it/s]


epoch:16 , Loss:1.520359754562378 , accuracy: 0.9391447368421052


100%|██████████| 38/38 [00:03<00:00, 10.95it/s]


epoch:17 , Loss:1.5114070177078247 , accuracy: 0.9498355263157894


100%|██████████| 38/38 [00:03<00:00, 11.05it/s]


epoch:18 , Loss:1.510176658630371 , accuracy: 0.9498355263157894


100%|██████████| 38/38 [00:03<00:00, 10.99it/s]

epoch:19 , Loss:1.4983556270599365 , accuracy: 0.9646381578947368





In [120]:
torch.save(model.state_dict(), "persian-mnist.pth")