In [1]:
# ==========================================
# 1. SETUP & IMPORTS
# ==========================================
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
import os
import zipfile
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from google.colab import files
import shutil
import pandas as pd
import zipfile
import tensorflow_datasets as tfds
from PIL import Image

# Device configuration (Use GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
# ==========================================
# 2. DATA DOWNLOAD & CONVERSION
# ==========================================
# I use TFDS to download because it's Google's most stable mirror for Colab
print("Downloading DeepWeeds via stable mirror...")
ds_builder = tfds.builder("deep_weeds")
ds_builder.download_and_prepare()

# Create organized folders for PyTorch
base_dir = "deepweeds_pytorch"
os.makedirs(base_dir, exist_ok=True)

print("Converting dataset to PyTorch-friendly format...")
info = ds_builder.info
class_names = info.features["label"].names

# Extract images from the downloaded archives into class folders
for split in ["train"]:
    ds = tfds.as_numpy(ds_builder.as_dataset(split=split))
    for i, example in enumerate(ds):
        image = Image.fromarray(example["image"])
        label_idx = example["label"]
        label_name = class_names[label_idx]

        class_folder = os.path.join(base_dir, label_name)
        os.makedirs(class_folder, exist_ok=True)

        image.save(os.path.join(class_folder, f"img_{i}.jpg"))
        if i % 1000 == 0: print(f"Processed {i} images...")

print("✅ Dataset successfully organized for PyTorch!")

Downloading DeepWeeds via stable mirror...
Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/deep_weeds/3.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/deep_weeds/incomplete.O8C48O_3.0.0/deep_weeds-train.tfrecord*...:   0%|   …

Dataset deep_weeds downloaded and prepared to /root/tensorflow_datasets/deep_weeds/3.0.0. Subsequent calls will reuse this data.
Converting dataset to PyTorch-friendly format...
Processed 0 images...
Processed 1000 images...
Processed 2000 images...
Processed 3000 images...
Processed 4000 images...
Processed 5000 images...
Processed 6000 images...
Processed 7000 images...
Processed 8000 images...
Processed 9000 images...
Processed 10000 images...
Processed 11000 images...
Processed 12000 images...
Processed 13000 images...
Processed 14000 images...
Processed 15000 images...
Processed 16000 images...
Processed 17000 images...
✅ Dataset successfully organized for PyTorch!


In [3]:
# ==========================================
# 3. DATA PREPROCESSING
# ==========================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

full_dataset = datasets.ImageFolder(base_dir, transform=data_transforms['train'])
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
val_dataset.dataset.transform = data_transforms['val']

dataloaders = {
    'train': DataLoader(train_dataset, batch_size=32, shuffle=True),
    'val': DataLoader(val_dataset, batch_size=32, shuffle=False)
}


In [4]:
# ==========================================
# 4. MODEL: RESNET50 (Partial Fine Tuning)
# ==========================================
model = models.resnet50(weights='IMAGENET1K_V1')

for param in model.parameters():
    param.requires_grad = False

# Replace final layer for the 9 DeepWeeds classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 135MB/s]


In [5]:
# ==========================================
# 5. TRAINING & METRICS
# ==========================================
def train(epochs=5):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in dataloaders['train']:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        correct = 0
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                correct += torch.sum(preds == labels.data)

        print(f"Epoch {epoch+1} | Loss: {running_loss/len(dataloaders['train']):.4f} | Val Acc: {correct.double()/val_size:.4f}")

train(epochs=5)

Epoch 1 | Loss: 0.9684 | Val Acc: 0.7407
Epoch 2 | Loss: 0.7233 | Val Acc: 0.7641
Epoch 3 | Loss: 0.6663 | Val Acc: 0.7730
Epoch 4 | Loss: 0.6266 | Val Acc: 0.7895
Epoch 5 | Loss: 0.6087 | Val Acc: 0.7964


In [6]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def evaluate_model_metrics(dataloader):
    model.eval()
    all_preds = []
    all_labels = []

    print("Calculating final metrics on validation set...")
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate individual metrics
    # 'weighted' accounts for label imbalance in the dataset
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    cm = confusion_matrix(all_labels, all_preds)


    print("\n" + "="*40)
    print("Test Set Performance:")
    print("="*40)
    print(f"Accuracy  : {accuracy:.4f}")
    print(f"Precision : {precision:.4f}")
    print(f"Recall    : {recall:.4f}")
    print(f"F1-score  : {f1:.4f}")

    print("\nConfusion Matrix:")
    print(cm)

    return accuracy, precision, recall, f1, cm


metrics = evaluate_model_metrics(dataloaders['val'])

Calculating final metrics on validation set...

Test Set Performance:
Accuracy  : 0.7964
Precision : 0.8043
Recall    : 0.7964
F1-score  : 0.7963

Confusion Matrix:
[[ 128    5   35    1   14    1    3    0   26]
 [   8  158   52    0    5    0    9    2   12]
 [  29   17 1599   12   61   23   29   14   33]
 [   0    0   19  175    7    5    0    0    0]
 [   1    0   28    3  162    3    2    0    0]
 [   0    0   22    6   19  144    0    0    1]
 [   2    1   40    3   10    0  158    1    5]
 [   1    6   64    1    2    0    4  151    3]
 [  19    1   31    0   10    0    1    1  114]]
