In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from tqdm import tqdm

In [46]:
# Seed torch's global RNG so that weight initialisation, DataLoader shuffling
# and the torch-driven random transforms repeat across kernel restarts.
# NOTE: some cuDNN kernels can still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)  # also seeds the CUDA generators

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
# Define the CNN architecture
class PlantVillageCNN(nn.Module):
    """Small three-block CNN classifier for 128x128 images with 3 channels.

    Each block is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, so the spatial
    size halves three times (128 -> 64 -> 32 -> 16) before a two-layer
    fully-connected head with dropout.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 16, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(16, 8, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        # 8 channels x 16 x 16 spatial positions, valid for 128x128 inputs only.
        self.fc1 = nn.Linear(8 * 16 * 16, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map images of shape (N, 3, 128, 128) to logits of shape (N, num_classes)."""
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = self.pool3(nn.functional.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 2048), an input of the wrong size now fails in fc1 instead
        # of being silently re-split into a different number of "samples".
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [48]:
# Per-split preprocessing. Both pipelines yield 128x128 tensors normalised
# with the ImageNet channel statistics; only the training split gets random
# crop / horizontal-flip augmentation.
_imagenet_norm = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**_imagenet_norm),
    ]),
    'val': transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize(**_imagenet_norm),
    ]),
}

In [49]:
# Load the datasets. ImageFolder expects one sub-directory per class under
# each split directory and derives label indices from the class folder names.
train_dataset = ImageFolder('../Datasets/train', transform=data_transforms['train'])
val_dataset = ImageFolder('../Datasets/val', transform=data_transforms['val'])

# Both splits must define exactly the same classes; otherwise the same label
# index would mean different classes in training and validation.
assert train_dataset.classes == val_dataset.classes, (
    f'class mismatch between train and val: '
    f'{train_dataset.classes} vs {val_dataset.classes}'
)


In [50]:

# Create data loaders with 64 images per batch. Only the training split is
# shuffled; validation keeps a fixed order so every evaluation sees the same
# batches.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)


In [51]:
# Build the network on the selected device. Cross-entropy is applied to the
# raw logits it returns; Adam runs at its default learning rate (1e-3).
model = PlantVillageCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


In [52]:
# Train the model, evaluating on the validation split after every epoch.
num_epochs = 2
for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()  # enable dropout
    running_loss = 0.0  # sum of per-sample training losses this epoch
    train_total = 0

    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch *mean* loss; weight it by the batch size
        # so a short final batch does not skew the epoch average.
        running_loss += loss.item() * labels.size(0)
        train_total += labels.size(0)

    # --- Validation pass ---
    model.eval()  # disable dropout
    val_loss = 0.0  # sum of per-sample validation losses
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc='Validating'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
            val_loss += criterion(outputs, labels).item() * labels.size(0)

    val_accuracy = val_correct / val_total

    # Print per-sample averaged metrics for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Training Loss: {running_loss/train_total:.4f}')
    print(f'Validation Loss: {val_loss/val_total:.4f}')
    print(f'Validation Accuracy: {val_accuracy:.4f}')
    print('-' * 50)

print('Training completed!')

Epoch 1/2: 100%|██████████| 157/157 [01:13<00:00,  2.15it/s]
Validating: 100%|██████████| 16/16 [00:03<00:00,  5.29it/s]


Epoch [1/2]
Training Loss: 1.8428
Validation Loss: 1.3756
Validation Accuracy: 0.5600
--------------------------------------------------


Epoch 2/2: 100%|██████████| 157/157 [00:28<00:00,  5.59it/s]
Validating: 100%|██████████| 16/16 [00:01<00:00,  8.34it/s]

Epoch [2/2]
Training Loss: 1.3435
Validation Loss: 1.0020
Validation Accuracy: 0.6680
--------------------------------------------------
Training completed!





Keep a copy of the original weights (detached clones) so that later we can verify that fine-tuning with LoRA does not alter them.

In [53]:
# Snapshot every parameter, keyed by its qualified name. The copies are
# detached clones, so later in-place updates to the live model cannot alter
# them.
original_weights = {
    name: param.detach().clone()
    for name, param in model.named_parameters()
}

Test the performance of the pretrained network. As the per-class results below show, the network performs worst on class 6 (accuracy 0.38), followed by class 3 (0.48). Let's fine-tune it on the class it struggles with most.

In [55]:
def test():
    # First, make sure model is on the correct device
    model.eval()
    
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]  # Adjust number based on your classes

    with torch.no_grad():
        for data in tqdm(val_loader, desc='Testing'):
            x, y = data
            # Move both inputs and labels to the same device as model
            x = x.to(device)  # Should be shape [batch_size, 3, 256, 256]
            y = y.to(device)
            
            outputs = model(x)
            
            for idx, output in enumerate(outputs):
                predicted = torch.argmax(output)
                if predicted == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx].item()] += 1
                total += 1

    accuracy = round(correct/total, 3)
    print(f'Overall Accuracy: {accuracy}')
    
    # Print wrong predictions for each class
    for i in range(len(wrong_counts)):
        print(f'Wrong predictions for class {i}: {wrong_counts[i]}')
        # Calculate per-class accuracy
        class_total = sum(1 for _, label in val_loader.dataset if label == i)
        class_accuracy = round((class_total - wrong_counts[i]) / class_total, 3)
        print(f'Class {i} Accuracy: {class_accuracy}')

test()

Testing: 100%|██████████| 16/16 [00:02<00:00,  6.58it/s]


Overall Accuracy: 0.668
Wrong predictions for class 0: 29
Class 0 Accuracy: 0.71
Wrong predictions for class 1: 43
Class 1 Accuracy: 0.57
Wrong predictions for class 2: 27
Class 2 Accuracy: 0.73
Wrong predictions for class 3: 52
Class 3 Accuracy: 0.48
Wrong predictions for class 4: 34
Class 4 Accuracy: 0.66
Wrong predictions for class 5: 29
Class 5 Accuracy: 0.71
Wrong predictions for class 6: 62
Class 6 Accuracy: 0.38
Wrong predictions for class 7: 13
Class 7 Accuracy: 0.87
Wrong predictions for class 8: 23
Class 8 Accuracy: 0.77
Wrong predictions for class 9: 20
Class 9 Accuracy: 0.8


Let's count how many parameters the original network has, before introducing the LoRA matrices.

In [None]:
# Count the weights and biases of every Conv2d / Linear child layer of the
# original network (before any LoRA matrices are added).
total_parameters_original = 0

# Number only the layers that are actually counted, so the labels run
# 1, 2, 3, ... instead of skipping over the pooling and dropout children.
counted_layers = [
    layer for layer in model.children()
    if isinstance(layer, (nn.Conv2d, nn.Linear))
]
for index, layer in enumerate(counted_layers, start=1):
    total_parameters_original += layer.weight.nelement()
    if layer.bias is not None:
        total_parameters_original += layer.bias.nelement()
    bias_shape = layer.bias.shape if layer.bias is not None else "N/A"
    print(f'Layer {index}: W: {layer.weight.shape} + B: {bias_shape}')

print(f'Total number of parameters: {total_parameters_original:,}')


In [None]:
# # Save the trained model
# torch.save(model.state_dict(), 'sequential_model.pth')
# print('Saved trained model as sequential_model.pth')

Saved trained model as sequential_model.pth
