In [None]:
!pip install wandb

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import wandb

In [None]:
# Start a Weights & Biases run under the "17-Flowers" project; later cells
# rely on this run through wandb.config, wandb.watch and wandb.log.
wandb.init(project="17-Flowers")

In [None]:
# Backbone: torchvision's ResNet-50 with pretrained weights.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=` -- confirm against the installed version before changing it.
model = torchvision.models.resnet50(pretrained=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Transfer learning: replace the network's final fully-connected layer with a
# freshly initialised linear layer that outputs one logit per flower class.
num_classes = 17
in_features = model.fc.in_features  # input width of the existing head
model.fc = nn.Linear(in_features=in_features, out_features=num_classes)

In [None]:
# Freeze the first six top-level children of the model (for torchvision's
# ResNet-50: conv1, bn1, relu, maxpool, layer1 and layer2). The remaining
# children keep requires_grad=True and are the only ones fine-tuned.
# (To freeze every parameter instead, loop over model.parameters() directly.)
num_frozen_children = 6
for child in list(model.children())[:num_frozen_children]:
    for param in child.parameters():
        param.requires_grad_(False)

In [None]:
# Move the model's parameters and buffers onto the device chosen earlier
# (CUDA when available, otherwise CPU).
model = model.to(device)

In [None]:
# Hyperparameters are stored on wandb.config so they are attached to the run
# and can be read back by the cells below.
config = wandb.config
hyperparameters = {
    "learning_rate": 0.01,
    "batch_size": 32,
    "epochs": 15,
}
for key, value in hyperparameters.items():
    setattr(config, key, value)

In [None]:
# Dataset
# Per-channel mean / std used for input normalisation (the values commonly
# used with ImageNet-pretrained torchvision models).
channel_mean = (0.485, 0.456, 0.406)
channel_std = (0.229, 0.224, 0.225)

# Augmentation (random horizontal flip, random rotation of up to 20 degrees),
# then resize to 224x224, convert to a tensor and normalise.
pipeline_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=20),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(pipeline_steps)

# Training images are read from Google Drive via ImageFolder.
train_root = "/content/drive/MyDrive/datasets/Flowers/Train"
dataset = torchvision.datasets.ImageFolder(root=train_root, transform=transform)
train_data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Classification objective: cross-entropy over the logits produced by the model.
loss_function = nn.CrossEntropyLoss()

# Adam over the model's parameters, using the learning rate held in the
# wandb config.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config.learning_rate,
)

In [None]:
def calc_acc(preds, labels):
    """Return the fraction of rows in ``preds`` whose arg-max matches ``labels``.

    Args:
        preds: Tensor of per-class scores; the predicted class of each row is
            the index of its largest value along dimension 1.
        labels: Tensor of target class indices, one per row of ``preds``.

    Returns:
        A 0-dim ``torch.float64`` tensor: number of matches divided by
        ``len(preds)``.
    """
    predicted_classes = preds.argmax(dim=1)
    num_correct = (predicted_classes == labels).sum(dtype=torch.float64)
    return num_correct / len(preds)

In [None]:
# Train
# Fine-tune the model for `config.epochs` epochs. After every epoch the mean
# per-batch loss and accuracy are printed; on even-numbered epochs they are
# also sent to Weights & Biases.
wandb.watch(model)

for epoch in range(config.epochs):
    # Explicitly select training mode so the loop behaves the same even if the
    # model was switched to eval mode elsewhere in the notebook session.
    model.train()

    train_loss = 0.0
    train_acc = 0.0
    for images, labels in train_data_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # 1- Forward pass
        preds = model(images)
        # 2- Backward pass
        loss = loss_function(preds, labels)
        loss.backward()
        # 3- Parameter update
        optimizer.step()

        # Accumulate plain Python floats. Summing the tensors themselves would
        # carry autograd history from batch to batch (growing memory use) and
        # leave tensors, not numbers, in the logged / printed statistics.
        train_loss += loss.item()
        train_acc += calc_acc(preds, labels).item()

    # Mean over batches (each batch contributes equally).
    total_loss = train_loss / len(train_data_loader)
    total_acc = train_acc / len(train_data_loader)

    if epoch % 2 == 0:
        # Single call so that both metrics are recorded at the same W&B step.
        wandb.log({"loss": total_loss, "acc": total_acc})

    print(f"Epoch: {epoch}, Loss: {total_loss}, Acc: {total_acc}")