In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models, datasets
import os
import tarfile
import requests
from tqdm import tqdm
import time
import copy
import timm

In [4]:

def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
                                 extracted_dirname='CUB_200_2011', timeout=30):
    """Download a gzip-compressed tar archive (if needed) and unpack it.

    The function is safe to re-run: it returns early when the extracted
    directory already exists, and it skips the download when the archive
    file is already present in ``download_root``.

    Args:
        url: Address of the ``.tgz`` archive.
        download_root: Directory in which the archive file is stored.
        extract_root: Directory to unpack into. Defaults to ``download_root``.
        filename: Name of the archive file. Defaults to the basename of ``url``.
        extracted_dirname: Directory name under ``extract_root`` whose presence
            means the archive has already been unpacked.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        None.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``response.raise_for_status()``).
    """
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)
    archive_path = os.path.join(download_root, filename)

    # Both directories must exist: the archive is written to download_root and
    # unpacked into extract_root, which may be different locations.
    os.makedirs(download_root, exist_ok=True)
    os.makedirs(extract_root, exist_ok=True)

    if os.path.exists(os.path.join(extract_root, extracted_dirname)):
        print("Dataset already extracted.")
        return

    if not os.path.exists(archive_path):
        print(f"Downloading {url} to {archive_path}")
        # Stream into a temporary ".part" file and rename only on success, so an
        # interrupted download is never mistaken for a complete archive later.
        partial_path = archive_path + '.part'
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024
                # total=None gives an open-ended bar when the size is unknown.
                with open(partial_path, 'wb') as f, tqdm(desc=filename, total=total_size or None, unit='iB', unit_scale=True, unit_divisor=1024) as bar:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        bar.update(len(data))
            os.replace(partial_path, archive_path)
        except BaseException:
            # BaseException so a notebook interrupt (KeyboardInterrupt) also
            # cleans up the partial file before propagating.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        print("Download complete.")
    else:
        print("Archive file already exists.")

    print(f"Extracting {archive_path} to {extract_root}...")
    with tarfile.open(archive_path, 'r:gz') as tar:
        if hasattr(tarfile, 'data_filter'):
            # The archive comes from the network: refuse members that would
            # escape extract_root (absolute paths, "..", unsafe links).
            tar.extractall(path=extract_root, filter='data')
        else:
            # Older Python without extraction filters.
            tar.extractall(path=extract_root)
    print("Extraction complete.")

# Per-channel mean / std handed to Normalize in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop / flip / colour jitter, then tensor + normalize.
_train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation pipeline: fixed resize + centre crop, then tensor + normalize.
_val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

data_transforms = {'train': _train_pipeline, 'val': _val_pipeline}

DATA_DIR = 'data'
DATASET_URL = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz'
# Fixed seed so the train/val partition is identical after every kernel restart.
# Without it, a checkpoint evaluated in a later session would see images it was
# trained on inside the "validation" subset.
SPLIT_SEED = 42

download_and_extract_archive(DATASET_URL, DATA_DIR)
image_dir = os.path.join(DATA_DIR, 'CUB_200_2011', 'images')

# transform=None: the dataset yields untransformed samples; the per-split
# transforms are attached later when the subsets are wrapped.
full_dataset = datasets.ImageFolder(root=image_dir, transform=None)
class_names = full_dataset.classes
num_classes = len(class_names)

# 80 / 20 split; val_size is the remainder so the two sizes always sum to the total.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_subset, val_subset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

class TransformedSubset(Dataset):
    """Dataset wrapper that applies an optional transform to each sample.

    Items are read from ``subset`` as ``(sample, target)`` pairs; when a
    transform was supplied it is applied to the sample only, and the target
    is passed through unchanged.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        # Same length as the wrapped subset.
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        return (self.transform(sample) if self.transform else sample), target

# Attach the phase-specific transform to each split.
train_dataset = TransformedSubset(train_subset, transform=data_transforms['train'])
val_dataset = TransformedSubset(val_subset, transform=data_transforms['val'])

_phase_datasets = {'train': train_dataset, 'val': val_dataset}

# Only the training loader shuffles; both use the same batch size and worker count.
dataloaders = {
    phase: DataLoader(phase_ds, batch_size=64, shuffle=(phase == 'train'), num_workers=2)
    for phase, phase_ds in _phase_datasets.items()
}
dataset_sizes = {phase: len(phase_ds) for phase, phase_ds in _phase_datasets.items()}

print("\nData preparation complete.")

Downloading https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz to data/CUB_200_2011.tgz


CUB_200_2011.tgz: 100%|██████████| 1.07G/1.07G [00:40<00:00, 28.5MiB/s]


Download complete.
Extracting data/CUB_200_2011.tgz to data...
Extraction complete.

Data preparation complete.


In [5]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Request the pretrained 'vit_base_patch16_224' from timm with `num_classes`
# outputs, then move it to the selected device.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=num_classes,
).to(device)

print("\nModel setup complete with Vision Transformer.")

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=0.01)
# Cosine schedule over 25 scheduler steps, bottoming out at eta_min.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=25, eta_min=1e-6)

Using device: cuda:0


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]


Model setup complete with Vision Transformer.


In [8]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # FIX: Use the modern torch.amp.GradScaler syntax
    scaler = torch.amp.GradScaler('cuda')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # FIX: Use the modern torch.amp.autocast syntax
                with torch.amp.autocast('cuda'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)

                # Scale loss and perform backward pass only in training phase
                if phase == 'train':
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    model.load_state_dict(best_model_wts)
    return model

In [9]:
print("--- Starting High-Performance Vision Transformer Training ---")
# Rebind `model` to the returned network.
model = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=25,
)

--- Starting High-Performance Vision Transformer Training ---
Epoch 0/24
----------
train Loss: 3.6327 Acc: 0.2510
val Loss: 1.6841 Acc: 0.5954

Epoch 1/24
----------
train Loss: 1.5928 Acc: 0.6287
val Loss: 0.9078 Acc: 0.7744

Epoch 2/24
----------
train Loss: 1.1188 Acc: 0.7333
val Loss: 0.7006 Acc: 0.8304

Epoch 3/24
----------
train Loss: 0.9302 Acc: 0.7768
val Loss: 0.6181 Acc: 0.8448

Epoch 4/24
----------
train Loss: 0.8263 Acc: 0.8014
val Loss: 0.5674 Acc: 0.8499

Epoch 5/24
----------
train Loss: 0.7309 Acc: 0.8245
val Loss: 0.5400 Acc: 0.8550

Epoch 6/24
----------
train Loss: 0.6923 Acc: 0.8339
val Loss: 0.5255 Acc: 0.8596

Epoch 7/24
----------
train Loss: 0.6357 Acc: 0.8468
val Loss: 0.5077 Acc: 0.8668

Epoch 8/24
----------
train Loss: 0.5857 Acc: 0.8622
val Loss: 0.5045 Acc: 0.8660

Epoch 9/24
----------
train Loss: 0.5397 Acc: 0.8701
val Loss: 0.4893 Acc: 0.8723

Epoch 10/24
----------
train Loss: 0.5247 Acc: 0.8742
val Loss: 0.4950 Acc: 0.8651

Epoch 11/24
----------
t

In [10]:
# Persist only the model's state dict (not the whole module object).
MODEL_V2_PATH = 'species_identifier_model_v2.pth'
torch.save(obj=model.state_dict(), f=MODEL_V2_PATH)
print(f"\nHigh-performance model saved to {MODEL_V2_PATH}")


High-performance model saved to species_identifier_model_v2.pth
