In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm 

### #1: Dataset Class and Dataloader

In [2]:
class PalmsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data = ImageFolder(data_dir,transform=transform)
        self.transform=transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @property
    def classes(self):
        return self.data.classes

In [3]:
train_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])

In [None]:
train_set = PalmsDataset("train",transform=train_transform)

#### Dataset vs Dataloader

In [5]:
for image, label in train_set:
    break

image.shape, label

(torch.Size([3, 224, 224]), 0)

In [6]:
train_loader = DataLoader(batch_size=32, dataset=train_set, shuffle=True)

### #2: Classifier

In [None]:
class PalmDiseaseClassifier(nn.Module):
    def __init__(self, number_of_classes=9):
        super(PalmDiseaseClassifier, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1)
        
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.drop1 = nn.Dropout(0.2)
        self.fc1 = nn.Linear(128*28*28, number_of_classes)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        # print(x.shape) # to get the needed input size for fc1
        
        x = x.view(x.size(0), -1)
        # print(x.shape)
        x = self.drop1(x)
        x = self.fc1(x)
        
        return x
        
        
        

In [None]:
model = PalmDiseaseClassifier()

In [9]:
# the needed x shape for fc1
# add print(x.shape) after the pool, then comment the fc1 and run the model on a batch
# the output will be [batch size, channel, highet, width]. the input for fc will be C*H*W
for x, y in train_loader:
    break
model(x)

tensor([[-7.0882e-02, -5.1109e-01,  7.1248e-01, -6.6899e-01,  2.6139e-01,
          1.6226e-01, -6.4960e-01,  1.0716e+00, -4.4262e-01],
        [-5.8032e-01, -1.0455e+00,  1.8994e-01, -2.8990e-01,  3.7907e-01,
         -6.9123e-01, -1.3212e-01, -8.6309e-02,  7.4228e-01],
        [-8.0366e-01, -5.7059e-02,  1.1858e-02, -3.1909e-01, -1.2748e-01,
          1.0562e+00, -1.0301e-01, -2.3994e-01,  2.0301e-01],
        [-1.3421e-01,  2.9368e-01,  7.9003e-01, -2.7535e-01, -3.2956e-01,
         -2.6141e-01,  1.6692e-01, -3.0124e-01, -7.2717e-01],
        [-3.5331e-01,  1.6447e-01,  3.0154e-01, -1.1321e+00,  2.6669e-01,
          4.8662e-01, -7.4506e-01,  4.9547e-01, -4.0434e-01],
        [-5.3650e-01, -1.4516e-01,  3.3962e-01,  2.7971e-01,  5.5072e-01,
          1.7346e-01, -8.1457e-03,  2.3874e-01,  3.8237e-01],
        [-6.7001e-01, -9.7767e-01,  3.0849e-01, -1.2333e+00, -4.7234e-01,
          8.2979e-02, -3.1822e-01,  1.0678e+00,  3.2823e-01],
        [-2.4868e-01, -2.9906e-01,  4.7976e-01, 

In [13]:
print(str(model))

PalmClassifier(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (drop1): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=100352, out_features=9, bias=True)
)


In [14]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

In [None]:
val_dataset = PalmsDataset("valid", transform=transform)
test_dataset = PalmsDataset("test", transform=transform)

val_loader = DataLoader(val_dataset,batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
device = torch.device("cuda:0")

In [None]:
model = PalmDiseaseClassifier(number_of_classes=9).to(device=device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=25):
    best_acc = 0.0
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Train)"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        
        val_loss, val_acc = validate(model, val_loader, criterion)
        
        print(f"Epoch {epoch+1}: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_palm_cnn_temp.pth')
            print(f"Saved new best model (Acc: {best_acc:.2f}%)")

def validate(model, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_loss = running_loss / len(val_loader)
    val_acc = 100 * correct / total
    return val_loss, val_acc

train_model(model, train_loader, val_loader, criterion, optimizer, epochs=50)

Epoch 1/50 (Train): 100%|██████████| 64/64 [00:21<00:00,  2.91it/s]
Validating: 100%|██████████| 19/19 [00:05<00:00,  3.59it/s]


Epoch 1: Train Loss: 2.4977, Train Acc: 33.96% | Val Loss: 2.1183, Val Acc: 41.18%
Saved new best model (Acc: 41.18%)


Epoch 2/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.15it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.27it/s]


Epoch 2: Train Loss: 1.5807, Train Acc: 46.73% | Val Loss: 1.3979, Val Acc: 57.96%
Saved new best model (Acc: 57.96%)


Epoch 3/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.89it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.90it/s]


Epoch 3: Train Loss: 1.4632, Train Acc: 53.37% | Val Loss: 1.3712, Val Acc: 58.48%
Saved new best model (Acc: 58.48%)


Epoch 4/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.12it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.25it/s]


Epoch 4: Train Loss: 1.3810, Train Acc: 55.45% | Val Loss: 1.6363, Val Acc: 51.56%


Epoch 5/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.13it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.61it/s]


Epoch 5: Train Loss: 1.3421, Train Acc: 58.32% | Val Loss: 1.1761, Val Acc: 62.28%
Saved new best model (Acc: 62.28%)


Epoch 6/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.07it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.35it/s]


Epoch 6: Train Loss: 1.0942, Train Acc: 63.27% | Val Loss: 1.3756, Val Acc: 56.92%


Epoch 7/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.18it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.85it/s]


Epoch 7: Train Loss: 1.0392, Train Acc: 66.24% | Val Loss: 1.2590, Val Acc: 60.90%


Epoch 8/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.05it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.30it/s]


Epoch 8: Train Loss: 0.9593, Train Acc: 68.42% | Val Loss: 1.3257, Val Acc: 60.55%


Epoch 9/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.00it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.44it/s]


Epoch 9: Train Loss: 0.9958, Train Acc: 67.87% | Val Loss: 1.4412, Val Acc: 62.11%


Epoch 10/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.15it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.41it/s]


Epoch 10: Train Loss: 0.8637, Train Acc: 71.78% | Val Loss: 1.0932, Val Acc: 65.74%
Saved new best model (Acc: 65.74%)


Epoch 11/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.26it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.33it/s]


Epoch 11: Train Loss: 0.7217, Train Acc: 76.09% | Val Loss: 1.5649, Val Acc: 61.25%


Epoch 12/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.12it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.70it/s]


Epoch 12: Train Loss: 0.8017, Train Acc: 73.47% | Val Loss: 1.0519, Val Acc: 67.82%
Saved new best model (Acc: 67.82%)


Epoch 13/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.28it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.26it/s]


Epoch 13: Train Loss: 0.6989, Train Acc: 76.04% | Val Loss: 1.2065, Val Acc: 63.32%


Epoch 14/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.23it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.62it/s]


Epoch 14: Train Loss: 0.6247, Train Acc: 78.17% | Val Loss: 1.0535, Val Acc: 69.55%
Saved new best model (Acc: 69.55%)


Epoch 15/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.11it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.15it/s]


Epoch 15: Train Loss: 0.7585, Train Acc: 76.73% | Val Loss: 1.0458, Val Acc: 68.69%


Epoch 16/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.30it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.15it/s]


Epoch 16: Train Loss: 0.7545, Train Acc: 75.05% | Val Loss: 0.9797, Val Acc: 72.49%
Saved new best model (Acc: 72.49%)


Epoch 17/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.31it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.72it/s]


Epoch 17: Train Loss: 0.4924, Train Acc: 83.27% | Val Loss: 1.0222, Val Acc: 74.05%
Saved new best model (Acc: 74.05%)


Epoch 18/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.18it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.67it/s]


Epoch 18: Train Loss: 0.5767, Train Acc: 80.59% | Val Loss: 0.9255, Val Acc: 73.01%


Epoch 19/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.20it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.22it/s]


Epoch 19: Train Loss: 0.5236, Train Acc: 82.23% | Val Loss: 1.0579, Val Acc: 76.99%
Saved new best model (Acc: 76.99%)


Epoch 20/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.03it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.56it/s]


Epoch 20: Train Loss: 0.4812, Train Acc: 84.50% | Val Loss: 0.9338, Val Acc: 72.66%


Epoch 21/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.06it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.60it/s]


Epoch 21: Train Loss: 0.5837, Train Acc: 82.28% | Val Loss: 0.9185, Val Acc: 73.88%


Epoch 22/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.24it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.35it/s]


Epoch 22: Train Loss: 0.4657, Train Acc: 83.42% | Val Loss: 0.8666, Val Acc: 72.15%


Epoch 23/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.37it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.97it/s]


Epoch 23: Train Loss: 0.3863, Train Acc: 87.62% | Val Loss: 1.0726, Val Acc: 72.49%


Epoch 24/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.17it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.83it/s]


Epoch 24: Train Loss: 0.4431, Train Acc: 84.85% | Val Loss: 1.2126, Val Acc: 64.71%


Epoch 25/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.06it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.93it/s]


Epoch 25: Train Loss: 0.3961, Train Acc: 88.32% | Val Loss: 1.0409, Val Acc: 70.42%


Epoch 26/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.25it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.50it/s]


Epoch 26: Train Loss: 0.3978, Train Acc: 86.49% | Val Loss: 0.8402, Val Acc: 78.37%
Saved new best model (Acc: 78.37%)


Epoch 27/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.72it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.52it/s]


Epoch 27: Train Loss: 0.3484, Train Acc: 87.87% | Val Loss: 1.3456, Val Acc: 68.69%


Epoch 28/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.10it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.54it/s]


Epoch 28: Train Loss: 0.4395, Train Acc: 84.85% | Val Loss: 1.0541, Val Acc: 74.22%


Epoch 29/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.34it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.98it/s]


Epoch 29: Train Loss: 0.2589, Train Acc: 90.99% | Val Loss: 0.9467, Val Acc: 77.34%


Epoch 30/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.09it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.69it/s]


Epoch 30: Train Loss: 0.3178, Train Acc: 89.41% | Val Loss: 0.9810, Val Acc: 76.30%


Epoch 31/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.96it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.10it/s]


Epoch 31: Train Loss: 0.3382, Train Acc: 89.36% | Val Loss: 1.0659, Val Acc: 73.18%


Epoch 32/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.03it/s]
Validating: 100%|██████████| 19/19 [00:02<00:00,  8.68it/s]


Epoch 32: Train Loss: 0.4356, Train Acc: 87.13% | Val Loss: 1.3743, Val Acc: 62.46%


Epoch 33/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.17it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.35it/s]


Epoch 33: Train Loss: 0.6938, Train Acc: 81.44% | Val Loss: 1.1121, Val Acc: 71.45%


Epoch 34/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.21it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.70it/s]


Epoch 34: Train Loss: 0.3255, Train Acc: 88.86% | Val Loss: 1.1213, Val Acc: 75.78%


Epoch 35/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.08it/s]
Validating: 100%|██████████| 19/19 [00:02<00:00,  9.09it/s]


Epoch 35: Train Loss: 0.3085, Train Acc: 90.15% | Val Loss: 1.1512, Val Acc: 73.36%


Epoch 36/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.08it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.43it/s]


Epoch 36: Train Loss: 0.4813, Train Acc: 85.40% | Val Loss: 0.8681, Val Acc: 78.37%


Epoch 37/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.98it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.70it/s]


Epoch 37: Train Loss: 0.2833, Train Acc: 90.25% | Val Loss: 1.0672, Val Acc: 76.82%


Epoch 38/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.93it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.86it/s]


Epoch 38: Train Loss: 0.2185, Train Acc: 92.13% | Val Loss: 1.1392, Val Acc: 75.95%


Epoch 39/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.08it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.69it/s]


Epoch 39: Train Loss: 0.2307, Train Acc: 92.23% | Val Loss: 0.9107, Val Acc: 72.32%


Epoch 40/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.15it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.66it/s]


Epoch 40: Train Loss: 0.2341, Train Acc: 92.87% | Val Loss: 0.9948, Val Acc: 76.47%


Epoch 41/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.02it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.31it/s]


Epoch 41: Train Loss: 0.2178, Train Acc: 91.98% | Val Loss: 0.8358, Val Acc: 80.45%
Saved new best model (Acc: 80.45%)


Epoch 42/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.02it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.61it/s]


Epoch 42: Train Loss: 0.2133, Train Acc: 93.42% | Val Loss: 0.9994, Val Acc: 76.99%


Epoch 43/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.87it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.40it/s]


Epoch 43: Train Loss: 0.2511, Train Acc: 91.19% | Val Loss: 1.2275, Val Acc: 72.84%


Epoch 44/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.17it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.29it/s]


Epoch 44: Train Loss: 0.3912, Train Acc: 87.82% | Val Loss: 1.0213, Val Acc: 77.85%


Epoch 45/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.09it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.64it/s]


Epoch 45: Train Loss: 0.2106, Train Acc: 93.32% | Val Loss: 1.1308, Val Acc: 75.43%


Epoch 46/50 (Train): 100%|██████████| 64/64 [00:08<00:00,  7.17it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.38it/s]


Epoch 46: Train Loss: 0.2863, Train Acc: 90.64% | Val Loss: 1.0472, Val Acc: 75.26%


Epoch 47/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.97it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.82it/s]


Epoch 47: Train Loss: 0.2235, Train Acc: 93.12% | Val Loss: 0.9045, Val Acc: 79.93%


Epoch 48/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.02it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.44it/s]


Epoch 48: Train Loss: 0.1691, Train Acc: 93.96% | Val Loss: 1.1335, Val Acc: 75.78%


Epoch 49/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  6.95it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00, 10.14it/s]


Epoch 49: Train Loss: 0.1593, Train Acc: 94.36% | Val Loss: 1.2050, Val Acc: 73.18%


Epoch 50/50 (Train): 100%|██████████| 64/64 [00:09<00:00,  7.07it/s]
Validating: 100%|██████████| 19/19 [00:01<00:00,  9.68it/s]

Epoch 50: Train Loss: 0.1982, Train Acc: 93.66% | Val Loss: 1.1035, Val Acc: 76.99%



