<a href="https://colab.research.google.com/github/Gabrielms-1/vit-classification-colab/blob/main/fine_tuning_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install torch torchvision wandb seaborn matplotlib numpy --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m113.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m85.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import argparse
import yaml
import os
import torch
from dataset import FolderBasedDataset, create_dataloader
from model import VisionTransformer
import wandb
from datetime import datetime
import io
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt

def process_data(train_dataset, val_dataset, resize, batch_size):
    """Build the training and validation data loaders.

    Args:
        train_dataset: path of the training image folder (``main`` passes
            ``args["train_dir"]``).
        val_dataset: path of the validation image folder (``main`` passes
            ``args["val_dir"]``).
        resize: forwarded to ``FolderBasedDataset`` (presumably the target
            image size — TODO confirm against dataset.py).
        batch_size: forwarded to ``create_dataloader``.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    # Wrap each folder path in a dataset; train is built before val, as before.
    train_ds, val_ds = (
        FolderBasedDataset(folder, resize) for folder in (train_dataset, val_dataset)
    )

    loaders = create_dataloader(train_ds, val_ds, batch_size)
    train_loader, val_loader = loaders
    return train_loader, val_loader

def compute_metrics(confusion_matrix):
    """Compute per-class precision, recall and F1 from a confusion matrix.

    Args:
        confusion_matrix: square tensor indexed ``[true_label, predicted_label]``
            (rows = ground truth, columns = predictions), as filled by
            ``evaluate_model``. Integer tensors are converted to float.

    Returns:
        ``(precision, recall, f1_score)``: 1-D tensors with one entry per class.
        Classes with no predictions / no true samples get 0 instead of NaN.
    """
    cm = confusion_matrix if confusion_matrix.is_floating_point() else confusion_matrix.float()
    eps = 1e-10  # avoids 0/0 for classes that never occur / are never predicted

    true_positives = torch.diag(cm)
    predicted_per_class = torch.sum(cm, dim=0)  # column sums: how often each class was predicted
    actual_per_class = torch.sum(cm, dim=1)     # row sums: how often each class truly occurs

    precision = true_positives / (predicted_per_class + eps)
    recall = true_positives / (actual_per_class + eps)
    f1_score = 2 * precision * recall / (precision + recall + eps)

    return precision, recall, f1_score

def evaluate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Puts the model in eval mode (callers that keep training must switch it
    back with ``model.train()``).

    Args:
        model: the network; called as ``model(data)`` and expected to return
            logits of shape ``(batch, n_classes)``.
        val_loader: iterable of ``(data, target, _)`` batches.
        criterion: loss function; assumed to return the batch-mean loss
            (e.g. ``CrossEntropyLoss`` default) since it is re-weighted by batch size.
        device: device the batches are moved to.

    Returns:
        ``(average_loss, accuracy, precision, recall, f1_score, confusion_matrix)``
        where precision/recall/f1 are per-class tensors from ``compute_metrics``
        and ``confusion_matrix`` is an ``(n_classes, n_classes)`` CPU tensor
        indexed ``[true_label, predicted_label]``.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()

    val_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    confusion_matrix = None  # allocated once the number of classes is known

    with torch.no_grad():
        for data, target, _ in val_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)
            _, predicted = torch.max(output, 1)
            loss = criterion(output, target)

            batch_size = target.size(0)
            val_loss += loss.item() * batch_size
            total_samples += batch_size
            correct_predictions += (predicted == target).sum().item()

            n_classes = output.size(1)
            if confusion_matrix is None:
                confusion_matrix = torch.zeros(n_classes, n_classes)
            # Vectorized confusion-matrix update: encode each (true, pred) pair as
            # one flat index, count occurrences, and fold back into a matrix.
            flat_idx = target.long() * n_classes + predicted.long()
            counts = torch.bincount(flat_idx, minlength=n_classes * n_classes)
            confusion_matrix += counts.reshape(n_classes, n_classes).cpu().to(confusion_matrix.dtype)

    if total_samples == 0:
        raise ValueError("val_loader yielded no samples")

    # Average over samples actually seen (robust to drop_last / custom samplers).
    average_loss = val_loss / total_samples
    accuracy = correct_predictions / total_samples

    precision, recall, f1_score = compute_metrics(confusion_matrix)

    return average_loss, accuracy, precision, recall, f1_score, confusion_matrix


def _train_one_epoch(model, optimizer, criterion, train_loader, device):
    """Run one optimisation pass over ``train_loader``; return ``(mean_loss, accuracy)``."""
    # evaluate_model() switches the model to eval mode, so training mode must be
    # re-enabled at the start of every epoch, not just once before the loop.
    model.train()

    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for data, target, _ in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output.detach(), 1)
        batch_size = target.size(0)
        running_loss += loss.item() * batch_size
        correct_predictions += (predicted == target).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("train_loader yielded no samples")

    return running_loss / total_samples, correct_predictions / total_samples


def train_model(model, total_epochs, optimizer, criterion, train_loader, val_loader, device):
    """Train ``model`` for ``total_epochs`` epochs, validating after each one.

    Each epoch logs train/val loss, accuracy and per-class precision/recall/F1
    to Weights & Biases and prints a short summary.

    Args:
        model: network to train (already on ``device``).
        total_epochs: number of passes over ``train_loader``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function, assumed to return the batch-mean loss.
        train_loader / val_loader: iterables of ``(data, target, _)`` batches.
        device: device the batches are moved to.

    Returns:
        ``(train_losses, val_losses, val_accuracies, train_accuracies, confusion_matrix)``
        where the lists hold one value per epoch and ``confusion_matrix`` is the
        one from the final validation pass (``None`` if ``total_epochs == 0``).
    """
    train_losses = []
    val_losses = []
    val_accuracies = []
    train_accuracies = []
    confusion_matrix = None

    for epoch in range(total_epochs):
        epoch_loss, epoch_accuracy = _train_one_epoch(model, optimizer, criterion, train_loader, device)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        val_loss, val_accuracy, val_precision, val_recall, val_f1_score, confusion_matrix = evaluate_model(
            model, val_loader, criterion, device
        )
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": epoch_loss,
            "val_loss": val_loss,
            "train_accuracy": epoch_accuracy,
            "val_accuracy": val_accuracy,
            "precision": val_precision,
            "recall": val_recall,
            "f1_score": val_f1_score,
        })

        # Macro-averaged F1 across classes for the console summary.
        mean_f1 = torch.mean(val_f1_score).item()

        print("-" * 50)
        print(f"EPOCH: {epoch+1}")
        print(f"- train_loss: {epoch_loss:.4f} | train_accuracy: {epoch_accuracy:.4f}")
        print(f"- val_loss: {val_loss:.4f} | val_accuracy: {val_accuracy:.4f} | f1_score: {mean_f1:.4f}")
        print("-" * 50)

    return train_losses, val_losses, val_accuracies, train_accuracies, confusion_matrix

def main(args, config):
    """Train a VisionTransformer, log the run to W&B and save the final weights.

    Args:
        args: run options with keys ``"epochs"``, ``"train_dir"``, ``"val_dir"``,
            ``"batch_size"``, ``"resize"``, ``"timestamp"`` and optionally
            ``"save_dir"`` (directory for ``model_final.pth``; defaults to
            ``"/content/"``).
        config: parsed training config; reads ``config["TRAIN"]["lr"]`` and the
            ``config["MODEL"]`` hyper-parameters passed to ``VisionTransformer``.
    """
    model_cfg = config["MODEL"]

    wandb.init(
        project="weeds-classification-vit",
        name=f"weeds-vit-model_{args['timestamp']}",
        config={
            "epochs": args["epochs"],
            # Log the batch size the loaders are actually built with.
            "batch_size": args["batch_size"],
            "learning_rate": config["TRAIN"]["lr"],
            "d_model": model_cfg["d_model"],
            "n_classes": model_cfg["n_classes"],
            "img_size": model_cfg["img_size"],
            "patch_size": model_cfg["patch_size"],
            "n_channels": model_cfg["n_channels"],
            "n_heads": model_cfg["n_heads"],
            "n_layers": model_cfg["n_layers"],
        },
    )

    train_loader, val_loader = process_data(args["train_dir"], args["val_dir"], args["resize"], args["batch_size"])
    # Used only for its int_to_label_map (class names on the heatmap axes).
    # NOTE(review): val_loader.dataset may already expose the same map — confirm and reuse.
    val_dataset = FolderBasedDataset(args["val_dir"], args["resize"])

    model = VisionTransformer(
        model_cfg["d_model"],
        model_cfg["n_classes"],
        model_cfg["img_size"],
        model_cfg["patch_size"],
        model_cfg["n_channels"],
        model_cfg["n_heads"],
        model_cfg["n_layers"],
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=config["TRAIN"]["lr"])
    criterion = torch.nn.CrossEntropyLoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    *_, confusion_matrix = train_model(
        model, args["epochs"], optimizer, criterion, train_loader, val_loader, device
    )

    class_names = [str(val_dataset.int_to_label_map[i]) for i in range(confusion_matrix.shape[0])]

    # Render the final confusion matrix (rows = true, columns = predicted).
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(confusion_matrix.cpu().numpy(), annot=True, fmt='.0f', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')
    fig.tight_layout()

    # Round-trip through an in-memory PNG so W&B receives a static image.
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        with Image.open(buf) as cm_image:
            wandb.log({"confusion_matrix_image": wandb.Image(cm_image)})

    wandb.finish()

    save_dir = args.get("save_dir", "/content/")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "model_final.pth")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")




In [None]:
# Load the training config and launch a run with Drive-hosted data folders.
with open("/content/train.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

args = {
    "epochs": config["TRAIN"]["epochs"],
    "train_dir": "/content/drive/MyDrive/data/train",
    "val_dir": "/content/drive/MyDrive/data/val",
    "batch_size": config["TRAIN"]["batch_size"],
    # Images are resized to the model's input size.
    "resize": config["MODEL"]["img_size"],
    "timestamp": datetime.now().strftime("%Y%m%d-%H-%M-%S"),
}

main(args, config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mgabrielms-1[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




--------------------------------------------------
EPOCH: 1
- train_loss: 0.7143 | train_accuracy: 0.5698
- val_loss: 0.5832 | val_accuracy: 0.7687 | f1_score: 0.7676
--------------------------------------------------
--------------------------------------------------
EPOCH: 2
- train_loss: 0.6181 | train_accuracy: 0.6952
- val_loss: 0.5540 | val_accuracy: 0.7537 | f1_score: 0.7521
--------------------------------------------------
--------------------------------------------------
EPOCH: 3
- train_loss: 0.5925 | train_accuracy: 0.7000
- val_loss: 0.5525 | val_accuracy: 0.7164 | f1_score: 0.7100
--------------------------------------------------
--------------------------------------------------
EPOCH: 4
- train_loss: 0.6139 | train_accuracy: 0.6921
- val_loss: 0.5349 | val_accuracy: 0.7836 | f1_score: 0.7830
--------------------------------------------------
--------------------------------------------------
EPOCH: 5
- train_loss: 0.6174 | train_accuracy: 0.6730
- val_loss: 0.5486 | v