In [4]:

import argparse
import os
import random
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from cvtorchvision import cvtransforms
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
from torchgeo.datasets.eurosat import EuroSAT
from torchgeo.models.vit import ViTLarge16_Weights, vit_large_patch16_224
from torchvision.transforms import Normalize

from datasets.EuroSat.eurosat_dataset import EurosatDataset,Subset

In [None]:
def get_args_parser():
    """Build the command-line parser for the linear-probing run.

    The parser is created with ``add_help=False``, so it registers no
    ``-h/--help`` option of its own.

    Returns:
        argparse.ArgumentParser: parser exposing the training, model,
        optimizer, dataset, output and runtime options defined below.
    """
    parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)

    # Training schedule
    parser.add_argument('--batch_size', default=512, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay (default: 0 for linear probe following MoCo v1)')

    # NOTE(review): the default is None and main() hands args.lr to optim.SGD
    # as-is; no fallback is visible there, so --lr presumably has to be given
    # explicitly -- TODO confirm.
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')

    # * Finetuning params
    # store_false: args.global_pool is True unless --cls_token is passed.
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=10, type=int,
                        help='number of the classification types')

    # Output / logging
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')

    # Runtime
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--num_workers', default=10, type=int)

    # --pin_mem / --no_pin_mem write to the same dest; set_defaults makes
    # pinning the default, so only --no_pin_mem changes the outcome.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    return parser


In [None]:
# Per-band normalization statistics, one (label, mean, std) row per band.
# Row order defines the channel order of the derived lists; main() passes those
# lists to transforms.Normalize. The labels are the band names the original
# inline comments attached to each value.
_BAND_MEAN_STD = (
    ('B01', 1353.72696296, 897.27143653),
    ('B02', 1117.20222222, 736.01759721),
    ('B03', 1041.8842963, 684.77615743),
    ('B04', 946.554, 620.02902871),
    ('B05', 1199.18896296, 791.86263829),
    ('B06', 2003.00696296, 1341.28018273),
    ('B07', 2374.00874074, 1595.39989386),
    ('B08', 2301.22014815, 1545.52915718),
    ('B09', 732.18207407, 475.11595216),
    ('B10', 12.09952894, 98.26600935),
    ('B11', 1820.69659259, 1216.48651476),
    ('B12', 1118.20259259, 736.6981037),
    ('B8A', 2599.78311111, 1750.12066835),
)

# Same public shape as before: {'mean': [...13 floats...], 'std': [...13 floats...]}.
BAND_STATS = {
    'mean': [band_mean for _, band_mean, _ in _BAND_MEAN_STD],
    'std': [band_std for _, _, band_std in _BAND_MEAN_STD],
}

def main(args):
    # Set up device
    device = torch.device(args.device)
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.benchmark = True # note- comment in case of error for mps 

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Initialize model
    model = vit_large_patch16_224(weights=ViTLarge16_Weights.SENTINEL2_ALL_MAE)
    model.head = nn.Linear(model.head.in_features, args.nb_classes)

    model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)

    # Initialize head parameters (e.g., Kaiming normal for weights, zeros for bias)
    for m in model.head.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # freeze all but the head
    for _, p in model.named_parameters():
        p.requires_grad = False
    for _, p in model.head.named_parameters():
        p.requires_grad = True

    model.to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)

   

    # Load dataset
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=BAND_STATS['mean'], std=BAND_STATS['std']),
    ])
    
    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=BAND_STATS['mean'], std=BAND_STATS['std']),
    ])

    # Load EuroSAT dataset

    train_dataset = EuroSAT(root=args.data_path, split='train', transform=transform_train)
    val_dataset = EuroSAT(root=args.data_path, split='val', transform=transform_val)
    test_dataset = EuroSAT(root=args.data_path, split='test', transform=transform_val)

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.num_workers, pin_memory=args.pin_mem)
    val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.num_workers, pin_memory=args.pin_mem)
    test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers, pin_memory=args.pin_mem)

    # Training loop
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print(f'Epoch [{epoch+1}/{args.epochs}], Loss: {loss.item():.4f}')

        # Validation step
        model.eval()
        with