# Load the data

In [1]:
import gc
from src.training.pretrainedModels import get_pretrained_model
from src.colors import bcolors
from config import Config

# Colour-code helper; later cells wrap printed status messages in c.OKGREEN / c.OKBLUE ... c.ENDC.
# The name `c` is read throughout the notebook, so it must not be rebound to anything else.
c = bcolors()
# Project configuration object; later cells read TRAIN_FILE, DATA_DIR, TRAIN_MS_DIR and TEST_MS_DIR from it.
config = Config()
# Identifier handed to get_pretrained_model() and embedded in checkpoint / W&B names further down.
PRETRAINED_MODEL = "vit_b_16"

In [2]:
from sklearn.preprocessing import OneHotEncoder
from torchvision import transforms
import pandas as pd

from src.training.pretrainedModels import get_pretrained_model
from src.pickle_loader import save_object

# Dataset-wide constants consumed by the dataset / model cells below.
NUM_AUG = 1
NUM_CLASSES = 10
# Channel indices handed to the datasets as `select_chan` (one list per data source).
CHANNELS_2A = [3, 2, 1]
CHANNELS_2C = [3, 2, 1]

# get_pretrained_model returns a pair; only the second element (the transform) is needed here.
_, transform = get_pretrained_model(PRETRAINED_MODEL)

# Label tables: the training file comes from the project config, the second table from the
# working directory.
df = pd.read_csv(config.TRAIN_FILE)
df_test = pd.read_csv("labels.csv")

# Double-bracket selection already yields an (n_samples, 1) array, which is what the encoder expects.
label_column = df[['label']].values
encoder = OneHotEncoder().fit(label_column)

# Persist the fitted encoder next to the data so other code can reuse the same class order.
save_object(encoder, config.DATA_DIR + "on_hot_encoder")
print(transform)

Compose(
    Resize(output_size=256, p=1.0, p_batch=1.0, same_on_batch=True, size=256, side=short, resample=bilinear, align_corners=True, antialias=False)
    CenterCrop(p=1.0, p_batch=1.0, same_on_batch=True, resample=bilinear, cropping_mode=slice, align_corners=True, size=(224, 224), padding_mode=zeros)
    Normalize(p=1.0, p_batch=1.0, same_on_batch=True, mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


In [3]:
# Normalisation statistics passed as `mean_std` to the EuroSatMS datasets.
# Three values per entry — presumably one per selected channel, in CHANNELS_2C order (TODO confirm).
mean_std_S2C = dict(
    mean=[1373.1, 1322.3, 1397.6],
    std=[1144.9, 878.7, 854.3],
)

In [4]:
from sklearn.model_selection import train_test_split
from src.datasets.EuroSatMS import EuroSatMS

# Fixed seed so the stratified split is identical after every kernel restart. Later cells
# evaluate checkpoints loaded from fixed paths; with an unseeded split the hold-out rows could
# overlap with rows a checkpoint was trained on in a previous session.
SPLIT_SEED = 42
train_df, val_df = train_test_split(
    df,
    test_size=0.04,
    stratify=df['label'],
    random_state=SPLIT_SEED,
)
print(df['label'].unique())

# Arguments that are identical for all three datasets.
common_ds_kwargs = dict(
    encoder=encoder,
    num_aug=NUM_AUG,
    transform=transform,
    mean_std=mean_std_S2C,
)

ds_train = EuroSatMS(
    train_df,
    config.TRAIN_MS_DIR,
    select_chan=CHANNELS_2C,
    **common_ds_kwargs,
)

# NOTE(review): the naming is crossed on purpose or by accident — `ds_val` is built from the
# separately labelled table (`df_test`, TEST_MS_DIR) while `ds_test` is the 4 % hold-out of the
# training table. The data module below receives them in (train, val, test) order, so
# `val_loss` is measured on the external table. Confirm this is intended.
# NOTE(review): `ds_val` is normalised with `mean_std_S2C` although it selects CHANNELS_2A —
# verify that these statistics are appropriate for that data source.
ds_val = EuroSatMS(
    df_test,
    config.TEST_MS_DIR,
    select_chan=CHANNELS_2A,
    **common_ds_kwargs,
)

ds_test = EuroSatMS(
    val_df,
    config.TRAIN_MS_DIR,
    select_chan=CHANNELS_2C,
    **common_ds_kwargs,
)

print(f"""\n{c.OKGREEN}Train dataset:      {len(ds_train)} samples{c.ENDC}""")
print(f"""{c.OKGREEN}Validation dataset: {len(ds_val)} samples{c.ENDC}""")
print(f"""{c.OKGREEN}Test dataset:       {len(ds_test)} samples{c.ENDC}""")

['AnnualCrop' 'Forest' 'HerbaceousVegetation' 'Highway' 'Industrial'
 'Pasture' 'PermanentCrop' 'Residential' 'River' 'SeaLake']

[92mPreloading images...[0m

[96mImages:         25920[0m
[96mAugmentations:  25920[0m
[96mJobs:           -4 [0m

[94mTime taken:      0 min 22.490834951400757 sec [0m

[92mPreloading images...[0m

[96mImages:         1003[0m
[96mAugmentations:  1003[0m
[96mJobs:           -4 [0m

[94mTime taken:      0 min 0.46628642082214355 sec [0m

[92mPreloading images...[0m

[96mImages:         1080[0m
[96mAugmentations:  1080[0m
[96mJobs:           -4 [0m

[94mTime taken:      0 min 0.8866140842437744 sec [0m

[92mTrain dataset:      25920 samples[0m
[92mValidation dataset: 1003 samples[0m
[92mTest dataset:       1080 samples[0m


In [5]:
import numpy as np

# Sanity check of what the datasets actually yield.
# `ii` is element 0 of the first ds_test sample; the calls below use it as a tensor (.shape, .numpy()).
ii = ds_test[0][0]

# First slice along axis 0 of the dataset item, next to the first element returned by
# ds_train.process_image(0), so the two value ranges can be compared by eye.
print(ii[0])
print(ds_train.process_image(0)[0])

print(ii.shape)
# Mean over axes 1 and 2, i.e. one value per entry along axis 0.
print(np.mean(ii.numpy(), axis=(1, 2)))

tensor([[-1.6035, -1.5611, -1.4862,  ..., -0.6051, -0.4419, -0.2927],
        [-1.5299, -1.4805, -1.3986,  ..., -0.4861, -0.3074, -0.1430],
        [-1.4323, -1.3703, -1.2763,  ..., -0.3492, -0.1565,  0.0199],
        ...,
        [-0.8893, -0.9173, -0.9587,  ..., -1.5494, -1.5523, -1.5631],
        [-0.9660, -1.0002, -1.0487,  ..., -1.4931, -1.4859, -1.4898],
        [-1.0228, -1.0610, -1.1132,  ..., -1.4530, -1.4367, -1.4339]])
[[[0.53968257 0.53968257 0.6825397  ... 0.46825397 0.46825397 0.34126985]
  [0.53968257 0.53968257 0.6825397  ... 0.46825397 0.46825397 0.34126985]
  [0.5        0.5        0.50793654 ... 0.52380955 0.61904764 0.3809524 ]
  ...
  [0.         0.         0.16666667 ... 0.8492063  0.8095238  0.76984125]
  [0.06349207 0.06349207 0.11111111 ... 0.7936508  0.6984127  0.71428573]
  [0.06349207 0.06349207 0.16666667 ... 0.76984125 0.73015875 0.8333333 ]]

 [[0.63       0.63       0.58       ... 0.43       0.45       0.37      ]
  [0.63       0.63       0.58       ... 

In [6]:
from kornia.constants import Resample
import kornia.augmentation as K
from torch import nn

# Application probabilities: p1 for the flips, p2 for affine / shear / brightness, p3 for saturation.
p1 = 0.6
p2 = 0.75
p3 = 0.4

# Options shared by the two geometric warps.
geometric_kwargs = dict(resample="nearest", padding_mode=2, p=p2)

augmentation_steps = [
    K.RandomHorizontalFlip(p=p1),
    K.RandomVerticalFlip(p=p1),
    K.RandomAffine(degrees=30, translate=None, scale=None, shear=None, **geometric_kwargs),
    K.RandomShear(shear=0.2, **geometric_kwargs),
    K.RandomBrightness((0.5, 1.5), p=p2),
    K.RandomSaturation((0.8, 1.2), p=p3),
    # Previously tried and currently disabled: RandomContrast((0.85, 1.15), p=p2),
    # RandomPlasmaContrast(roughness=(0.01, 0.15), p=p3), RandomSolarize(thresholds=0.1, p=p3),
    # RandomSharpness((0.1, 0.3), p=p3), RandomBoxBlur((3, 3), p=p1), RandomEqualize(p=p1).
    K.CenterCrop(size=(64, 64)),
]
augmentation = nn.Sequential(*augmentation_steps)

# Only the training set is augmented; the two evaluation sets are left untouched.
ds_train.augment = augmentation
for eval_ds in (ds_val, ds_test):
    eval_ds.augment = None

# CNN Model Training

In [None]:
import numpy as np
from src.training.data import EuroSatDataModule

BATCH_SIZE = 32
data_module = EuroSatDataModule(ds_train, ds_val, ds_test, BATCH_SIZE)
print(f"""{c.OKGREEN}Initialized the data module...{c.ENDC}""")

# Hyper-parameters of the from-scratch CNN.
KERNEL_SIZE = [5, 3]
LEARNING_RATE = 0.03
MOMENTUM = 0.9
GAMMA = 0.9
DROPOUT = 0.3
EPOCHS = 15
CKPT_PATH = 'checkpoints/cnn/'
CLASS_WEIGHTS = {'AnnualCrop': 9,
                 'Forest': 9,
                 'HerbaceousVegetation': 8,
                 'Highway': 9,
                 'Industrial': 9,
                 'Pasture': 9,
                 'PermanentCrop': 9,
                 'Residential': 9,
                 'River': 9,
                 'SeaLake': 9}

# Min-max scale the class weights to [0, 1].
# NOTE(review): min-max scaling maps the smallest weight to exactly 0.0, so with the values
# above 'HerbaceousVegetation' ends up with weight 0. If `weights` is used as a per-class loss
# weight inside LitEuroSatCnn, that class contributes nothing to the loss — confirm this is
# intended (dividing by the maximum instead would keep every class active).
w = np.array(list(CLASS_WEIGHTS.values()), dtype=float)
w_min, w_max = w.min(), w.max()
if w_max > w_min:
    w = (w - w_min) / (w_max - w_min)
else:
    # All weights equal: the range is 0 and the division would yield NaN; use uniform weights.
    w = np.ones_like(w)

# Encode the configuration in the checkpoint file name.
# The comprehension variables are deliberately NOT called `c`: a top-level `for c in ...` loop
# would overwrite the global colour helper `c` and break every later `c.OKGREEN` lookup.
fname = (
    "cnn_c" + "".join(str(chan) for chan in CHANNELS_2C)
    + "_k" + "".join(str(size) for size in KERNEL_SIZE)
    + "_lr" + str(LEARNING_RATE)
    + "_m" + str(MOMENTUM)
    + "_g" + str(GAMMA)
    + "_d" + str(DROPOUT)
)

print("Model name: ", fname)

In [None]:
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from datetime import datetime

from src.training.cnn import LitEuroSatCnn

# Class weights computed in the configuration cell, as a float32 tensor for the model.
class_weights = torch.tensor(w, dtype=torch.float32)

lightning_model = LitEuroSatCnn(
    num_classes=NUM_CLASSES,
    num_channels=len(CHANNELS_2C),
    kernel_size=KERNEL_SIZE,
    learning_rate=LEARNING_RATE,
    momentum=MOMENTUM,
    gamma=GAMMA,
    dropout=DROPOUT,
    weights=class_weights,
)

logger = WandbLogger(project="eurosat_cnn", name="cnn_v1", log_model=False)

# Checkpoints land in a sub-directory named after the wall-clock time (hour-minute) of this run.
run_stamp = datetime.now().strftime("%H-%M")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=2,
    dirpath=CKPT_PATH + run_stamp,
    filename=fname + '_{epoch:02d}-{val_loss:.2f}',
    verbose=False,
)

trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=EPOCHS,
    accelerator="gpu",
    devices=1,
)

trainer.fit(lightning_model, datamodule=data_module)

In [None]:
import torchmetrics
import matplotlib.pyplot as plt

# Evaluate the best checkpoint of the run above on the data module's test split.
best_ckpt = checkpoint_callback.best_model_path
trainer.test(lightning_model, datamodule=data_module, ckpt_path=best_ckpt)

# The model collects per-batch outputs in `ep_out` / `ep_true`; flatten them into single arrays.
all_preds, all_true = (
    np.concatenate(chunks)
    for chunks in (lightning_model.ep_out, lightning_model.ep_true)
)
pred_ep, true_ep = torch.tensor(all_preds), torch.tensor(all_true)

metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=NUM_CLASSES)
metric.update(pred_ep, true_ep)
fig, ax = metric.plot()
plt.show()

# Class order of the encoder, to read the matrix axes.
print(encoder.categories_[0])

# Pretrained Model Training
### Training Phase 1
##### Pretrained ResNet50, ResNet18, AlexNet, ViT Model

In [8]:
import torch
from tabulate import tabulate
from src.training.pretrainedModels import get_pretrained_model
from src.training.data import EuroSatDataModule
from src.training.pretrainedModels import EuroSatPreTrainedModel

# These two names were already set in the CNN section (32 and 'checkpoints/cnn/');
# from here on they hold the values for the pretrained-backbone experiments.
BATCH_SIZE = 128
CKPT_PATH = f'checkpoints/{PRETRAINED_MODEL}/'

# get_pretrained_model returns a pair; the first element is the backbone network,
# the second (the transform, already obtained earlier) is ignored here.
model, _ = get_pretrained_model(PRETRAINED_MODEL)

# Re-create the data module with the larger batch size.
data_module = EuroSatDataModule(ds_train, ds_val, ds_test, BATCH_SIZE)

# Lightning wrapper around the backbone. The training summary printed by the next cell lists
# a `backbone`, a `criterion` and a `classifier` sub-module, with only ~199 K of 86 M
# parameters trainable.
# NOTE(review): `layers=[256]` presumably sets the hidden sizes of the classifier head — confirm
# against EuroSatPreTrainedModel.
model_train = EuroSatPreTrainedModel(
    backbone=model,
    learning_rate=0.0005,
    layers=[256],
    opt="sgd",
    gamma=0.9,
    momentum=0.9,
    dropout=0.3,
    weight_decay=0.001
)
print(model_train)

EuroSatPreTrainedModel(
  (backbone): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
       

In [9]:
from datetime import datetime
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

# Weights & Biases run: the project is "eurosat_" plus the part of PRETRAINED_MODEL before the
# first underscore (e.g. "vit_b_16" -> "eurosat_vit"); the run is named after the full identifier.
logger = WandbLogger(
    project="eurosat_" + PRETRAINED_MODEL.split("_")[0],
    name=PRETRAINED_MODEL,
    log_model=False,
)

# Keep the five checkpoints with the lowest monitored "val_loss".
# The target directory is CKPT_PATH plus the current wall-clock time formatted as HH-MM, so two
# runs started in the same minute (on any day) write into the same directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=CKPT_PATH + datetime.now().strftime("%H-%M"),
    filename='{epoch:02d}-{val_loss:.2f}-{train_loss_epoch:.2f}',
    save_top_k=5, 
    monitor="val_loss",
    verbose=False
)

# Single-GPU trainer for phase 1 (five epochs).
trainer = Trainer(
    max_epochs=5,
    accelerator="gpu", 
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model_train, datamodule=data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mthe-virus[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | backbone   | VisionTransformer | 85.8 M
1 | criterion  | CrossEntropyLoss  | 0     
2 | classifier | Sequential        | 199 K 
-------------------------------------------------
199 K     Trainable params
85.8 M    Non-trainable params
86.0 M    Total params
343.992   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:
from tabulate import tabulate
from datetime import datetime
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

# Phase 2: reload a phase-1 checkpoint, unfreeze part of the backbone and continue training.
logger = WandbLogger(
    project="eurosat_" + PRETRAINED_MODEL.split("_")[0],
    name=PRETRAINED_MODEL,
    log_model=False,
)
# Fresh backbone instance to load the checkpoint weights into; the second return value is ignored.
model, _ = get_pretrained_model(PRETRAINED_MODEL)

# NOTE(review): the checkpoint path is hard-coded to a "resnet50_RGB_MOCO" run while the backbone
# is built from PRETRAINED_MODEL (set to "vit_b_16" at the top of the notebook). Unless
# PRETRAINED_MODEL is changed to the matching ResNet, the weights will presumably not fit the
# backbone — confirm, or switch back to checkpoint_callback.best_model_path.
model_eval = EuroSatPreTrainedModel.load_from_checkpoint(
    "checkpoints/resnet50_RGB_MOCO/12-15/epoch=02-val_loss=1.73.ckpt", #checkpoint_callback.best_model_path,
    backbone=model,
    learning_rate=0.001,
    layers=[],
    gamma=0.95,
    momentum=0.9,
    dropout=0.3,
    weight_decay=1e-5
)

# Make the last residual stage trainable again.
# NOTE(review): this cell reads `backbone.layer4` and `backbone.fc`. The VisionTransformer
# printed by the setup cell shows `conv_proj` and `encoder` sub-modules instead, so this cell
# appears to be written for a ResNet backbone — verify before running it with a ViT.
for param in model_eval.backbone.layer4.parameters():
    param.requires_grad = True

# for param in model_eval.backbone.layer3[-3:].parameters():
#     param.requires_grad = True

# Make the backbone's final fully connected layer trainable as well.
for param in model_eval.backbone.fc.parameters():
    param.requires_grad = True

# Overview table: one row per backbone parameter with its trainable flag and its size in
# thousands of elements (integer division, so tensors with < 1000 elements show as "0k").
tabel = []
for name, param in model_eval.backbone.named_parameters():
    num_params = f"{param.numel() // 1000}k"
    tabel.append([f"{c.OKGREEN}{name}{c.ENDC}", f"{c.OKBLUE}{param.requires_grad}{c.ENDC}", num_params])

print(tabulate(tabel, headers=[f"{c.OKGREEN}Layer{c.ENDC}", f"{c.OKBLUE}Trainable{c.ENDC}", f"Parameters"]))

# Keep the four checkpoints with the lowest "val_loss" in a directory named after the current HH-MM time.
checkpoint_callback = ModelCheckpoint(
    dirpath=CKPT_PATH + datetime.now().strftime("%H-%M"),
    filename='{epoch:02d}-{val_loss:.2f}',
    save_top_k=4, 
    monitor="val_loss",
    verbose=False
)

# Single-GPU trainer for phase 2 (ten epochs).
trainer = Trainer(
    max_epochs=10,
    accelerator="gpu", 
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model_eval, datamodule=data_module)

In [None]:
import torchmetrics
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Validate a specific phase-1 checkpoint (swap in checkpoint_callback.best_model_path to use the
# best checkpoint of the current session instead).
eval_ckpt = "checkpoints/vit_b_16/13-19/epoch=03-val_loss=1.44-train_loss_epoch=0.47.ckpt"
trainer.validate(model_train, datamodule=data_module, ckpt_path=eval_ckpt)

# Flatten the per-batch outputs collected by the model during validation.
all_preds, all_true = (
    np.concatenate(chunks) for chunks in (model_train.ep_out, model_train.ep_true)
)
pred_ep, true_ep = torch.tensor(all_preds), torch.tensor(all_true)

metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=NUM_CLASSES)
metric.update(pred_ep, true_ep)
confmat = metric.compute()
confmat_np = confmat.numpy()

# Axis labels in encoder order.
class_names = encoder.categories_[0]
tick_labels = [class_names[i] for i in range(NUM_CLASSES)]

cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    confmat_np,
    annot=True,
    fmt='g',
    cmap='Blues',
    xticklabels=tick_labels,
    yticklabels=tick_labels,
    ax=cm_ax,
)
cm_ax.set_xlabel('Predicted labels')
cm_ax.set_ylabel('True labels')
cm_ax.set_title('Confusion Matrix')
plt.show()

In [None]:
from tabulate import tabulate
from src.training.pretrainedModels import EuroSatPreTrainedModel
from datetime import datetime
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import BackboneFinetuning, ModelCheckpoint
from pytorch_lightning import Trainer
from src.training.data import EuroSatDataModule
import random
import torch
import gc
import wandb

CKPT_PATH = f'checkpoints/{PRETRAINED_MODEL}/'

# Fixed settings shared by every search round.
n_rounds = 20
gamma = 0.95
momentum = 0.9
dropout = 0.3

# Candidate values per hyper-parameter.
layers_samples = [[], [256]]
wdec = [0.01, 0.001, 0.0001]
optims = ["adam", "sgd"]
lr_samples = [0.01, 0.001, 0.0005, 0.0001]
bs_samples = [128]

# Full grid, visited in random order; only the first `n_rounds` combinations are trained.
param_space = [[l, b, wd, ly, opt] for l in lr_samples for b in bs_samples for wd in wdec for ly in layers_samples for opt in optims]
random.shuffle(param_space)

# Never ask for more rounds than there are combinations (would raise IndexError).
n_rounds = min(n_rounds, len(param_space))

hp_headers = ["Learning Rate", "Batch Size", "Weight Decay", "Layers", "Optimizer", "Accuracy"]
out_table = []

for i, (lr_hp, bs_hp, wd, ly, opt) in enumerate(param_space[:n_rounds]):
    print(f"{c.OKBLUE}Round {i+1}/{n_rounds}{c.ENDC}")
    print(f"{c.OKBLUE}Learning Rate: {lr_hp:.4f}, Batch Size: {bs_hp}, Weight Decay: {wd}, Layers: {ly}, Optimizer: {opt}{c.ENDC}")

    # Round-local objects use their own `hp_*` names. Re-using the notebook-wide names
    # (`data_module`, `model_train`, `trainer`, `checkpoint_callback`) and resetting them to
    # None at the end of each round would leave later cells, which still read `data_module`,
    # without a usable data module.
    hp_data_module = EuroSatDataModule(ds_train, ds_val, ds_test, bs_hp)

    hp_backbone, _ = get_pretrained_model(PRETRAINED_MODEL)
    hp_model = EuroSatPreTrainedModel(
        backbone=hp_backbone,
        layers=ly,
        opt=opt,
        learning_rate=lr_hp,
        gamma=gamma,
        momentum=momentum,
        dropout=dropout,
        weight_decay=wd
    )

    hp_checkpoint = ModelCheckpoint(
        dirpath=CKPT_PATH + "hp_tuning" + f"/round_{i}",
        filename=f"bs_{bs_hp}_lr_{lr_hp}_wd_{wd}_loss_" + '{val_loss:.2f}',
        save_top_k=1,
        monitor="val_loss",
        verbose=False
    )

    # Same project naming rule as the other pretrained-model cells, so the runs are filed under
    # the backbone that is actually being tuned.
    wandb_logger = WandbLogger(
        project="eurosat_" + PRETRAINED_MODEL.split("_")[0],
        name=f"bs_{bs_hp}_lr_{round(lr_hp, 4)}_wd_{wd}",
        log_model=False,
    )

    hp_trainer = Trainer(
        max_epochs=2,
        accelerator="gpu",
        devices=1,
        logger=wandb_logger,
        callbacks=[hp_checkpoint],
    )

    hp_trainer.fit(hp_model, datamodule=hp_data_module)
    # Re-validate with the best checkpoint of this round; the model exposes the result as `.accuracy`.
    hp_trainer.validate(hp_model, datamodule=hp_data_module, ckpt_path=hp_checkpoint.best_model_path, verbose=False)
    score = hp_model.accuracy
    out_table.append([lr_hp, bs_hp, wd, ly, opt, score])
    print(tabulate(out_table, headers=hp_headers))
    wandb.finish()

    # Drop every reference to this round's model (the trainer and callback hold one too)
    # before asking Python and CUDA to release memory.
    del hp_trainer, hp_checkpoint, hp_model, hp_backbone, hp_data_module
    gc.collect()
    torch.cuda.empty_cache()

print(tabulate(out_table, headers=hp_headers))

# Predict on the test set

In [None]:
from src.datasets.EuroSatTest import EuroSatTestSet
from torch.utils.data import DataLoader
from config import Config
from src.training.pretrainedModels import get_pretrained_model

# Build the unlabeled prediction set and its loader.
config = Config()

# Only the preprocessing transform (second return value) is needed here.
_, transform = get_pretrained_model(PRETRAINED_MODEL)

# Normalisation statistics handed to the test set as `mean_std`.
# Three values per entry — presumably one per channel in CHANNELS_2A order (TODO confirm).
# These differ from `mean_std_S2C`, which is used for the training datasets.
mean_std_S2A = {
    'mean': [1307.6, 1151.7, 889.6],
    'std': [1375.2, 1188.1, 1159.1]
}

# Disabled alternative transform (crop + flip), kept for reference:
# transform = transforms.Compose([
#     K.CenterCrop(size=(56, 56)),
#     K.RandomHorizontalFlip(),
# ])

# Test set over config.TEST_MS_DIR; the inference cell below unpacks each batch as (inputs, sample id).
dataset = EuroSatTestSet(config.TEST_MS_DIR, select_chan=CHANNELS_2A, add_B10=False, mean_std=mean_std_S2A, transform=transform) #, augment=augmentation)
# shuffle=False keeps the batches in dataset order.
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)


In [None]:
from src.training.pretrainedModels import EuroSatPreTrainedModel
from src.training.cnn import LitEuroSatCnn
from config import Config
from src.training.pretrainedModels import get_pretrained_model
from pytorch_lightning import Trainer

# Load the checkpoint that will be used for the final predictions and score it once.
config = Config()

# Disabled alternative: load the from-scratch CNN instead of a pretrained backbone.
# NOTE(review): the disabled call refers to `CHANNELS`, a name not defined in the visible cells.
# model_eval = LitEuroSatCnn.load_from_checkpoint(
#     "checkpoints/cnn/14-18/cnn_c87654_k53_lr0.03_m0.9_g0.8epoch=09-val_loss=0.68.ckpt", #checkpoint_callback.best_model_path,
#     num_classes=NUM_CLASSES,
#     learning_rate=0.025,
#     num_channels=len(CHANNELS),
#     kernel_size=[5, 3],
#     momentum=0.9,
#     gamma=0.9,
#     weights=torch.tensor(w, dtype=torch.float32)
# )
# Fresh backbone instance to load the checkpoint weights into; the second return value is ignored.
model, _ = get_pretrained_model(PRETRAINED_MODEL)
# NOTE(review): the checkpoint path is hard-coded to a "resnet50_RGB" run, while the backbone is
# built from PRETRAINED_MODEL ("vit_b_16" at the top of the notebook). Make sure both refer to
# the same architecture, or use checkpoint_callback.best_model_path.
model_eval = EuroSatPreTrainedModel.load_from_checkpoint(
    "checkpoints/resnet50_RGB/12-57/epoch=04-val_loss=1.52-train_loss_epoch=0.86.ckpt", #checkpoint_callback.best_model_path,
    backbone=model,
    learning_rate=1e-4,
    layers=[],
    momentum=0.9,
    dropout=0,
    weight_decay=0.0001
)

# Trainer used only for the test pass below; no callbacks are attached.
trainer = Trainer(
    max_epochs=4,
    accelerator="gpu", 
    devices=1,
    logger=None,
    callbacks=[],
)

# `data_module` must still reference a live EuroSatDataModule at this point; re-create it
# first if an earlier cell has reset it.
trainer.test(model_eval, datamodule=data_module)

# Switch to evaluation mode for the prediction loop that follows.
model_eval.eval()

In [None]:
import torch
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_eval = model_eval.to(device)
model_eval.eval()

N_CLASSES = 10
# Class names in encoder order; a predicted index is translated through this array.
categorys = dataset.enc.categories_[0]
print(categorys)

# Per-sample results, filled in loader order.
predictions = []    # predicted class name
probabilities = []  # currently unused
ohe = []            # currently unused
images = []         # the (already transformed) input tensors, moved back to the CPU
sample_ids = []     # plain Python numbers, one per sample

with torch.no_grad():
    for inputs, samp_id in dataloader:
        inputs = inputs.to(device)

        outputs = model_eval(inputs)
        # Index of the highest score per row.
        preds = outputs.argmax(dim=1).cpu().numpy()

        predictions.extend(categorys[p] for p in preds)
        images.extend(inputs.cpu())
        # Store the ids as plain Python numbers. Extending with the tensor itself would store
        # 0-d tensors, which render as "tensor(5)" when formatted into a file name later on.
        sample_ids.extend(samp_id.cpu().tolist())
    


In [None]:
import pandas as pd

# Assemble the submission table (one row per test sample) ordered by test id.
test_ids = np.array(sample_ids)
sub_df = pd.DataFrame(
    {'test_id': test_ids, 'label': np.array(predictions)}
).sort_values(by='test_id')

print(sub_df.head())
print(test_ids)

sub_df.to_csv('submission.csv', index=False)

# Predicted class distribution, as a quick plausibility check.
print(np.unique(predictions, return_counts=True))

In [None]:
from torch.nn.functional import interpolate

# [ 606,  550,  229,  628,  216,  182,  280,  447,   89, 1005]
# [ 876,  601,  219,  582,  195,   55,  164,  458,   63, 1019]
def overlay_cam_on_image(im, cam_mask, heat_alpha=0.3, image_alpha=0.5):
    """Blend a class-activation map onto an image as a 'jet' heatmap.

    Args:
        im: Image tensor whose last two dimensions are (height, width), e.g. (3, H, W).
        cam_mask: Activation map tensor with at most 4 dimensions: (h, w), (1, h, w) or
            (1, 1, h, w). Only the first map of the first batch entry is used.
        heat_alpha: Weight of the heatmap in the blend (default 0.3, as before).
        image_alpha: Weight of the image in the blend (default 0.5, as before).

    Returns:
        A float CPU tensor of shape (3, H, W) (for a three-channel `im`):
        ``heatmap * heat_alpha + im * image_alpha``.

    Raises:
        ValueError: If `cam_mask` has more than 4 dimensions.
    """
    mask = cam_mask.detach().float().cpu()
    if mask.dim() > 4:
        raise ValueError(f"cam_mask must have at most 4 dimensions, got {mask.dim()}")
    # `interpolate` needs (N, C, h, w) for 2-D resizing; pad missing leading dimensions.
    while mask.dim() < 4:
        mask = mask.unsqueeze(0)

    # Scale to [0, 1]; a constant map has zero range and is shown as all-zero instead of NaN.
    low, high = mask.min(), mask.max()
    if high > low:
        mask = (mask - low) / (high - low)
    else:
        mask = torch.zeros_like(mask)

    # Resize to the image's spatial size only (not its channel dimension), then drop the
    # batch / channel dimensions again so the colormap receives a 2-D array.
    mask = interpolate(mask, size=tuple(im.shape[-2:]), mode='nearest')[0, 0]

    # Colormap output is (H, W, 4) RGBA; keep RGB and move channels first.
    heatmap = plt.get_cmap('jet')(mask.numpy())[:, :, :3]
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float()

    combined_img = heatmap * heat_alpha + im.detach().cpu() * image_alpha

    return combined_img

In [None]:
from pytorch_grad_cam import GradCAM

import torch
import random
from matplotlib import pyplot as plt


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grad_cam_c = False  # switch for the (currently disabled) Grad-CAM overlay

# Random order in which the predicted samples are shown.
samp_batch_idx = [i for i in range(0, len(sample_ids))]
random.shuffle(samp_batch_idx)
samp_batch_idx = np.array(samp_batch_idx)
n = 10  # number of figures, eight samples each

# Enable gradients on the whole backbone (needed by Grad-CAM when it is switched on).
for param in model_eval.backbone.parameters():
    param.requires_grad = True


# target_layer = [model_eval.backbone.layer4[-1].conv1]


for batch_start in range(0, n*8, 8):  # Iterate in steps of 8
    batch_idx = samp_batch_idx[batch_start:batch_start+8]
    if len(batch_idx) == 0:
        # Fewer than n*8 samples available: stop instead of drawing empty figures.
        break

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))  # Create a new figure for each batch
    axs = axs.flatten()  # Flatten the grid for easy iteration
    for ax in axs:
        ax.axis('off')  # also hides panels left unused by a short last batch

    for idx, ax in zip(batch_idx, axs):
        pred = predictions[idx]
        # int() turns a 0-d tensor id into a plain number; formatting the tensor directly
        # would produce a file name like "test_tensor(5).npy".
        samp_id = int(sample_ids[idx])
        im_path = config.DATA_DIR + f"test/NoLabel/test_{samp_id}.npy"
        # img = images[idx].unsqueeze(0).requires_grad_(True).to(device)
        # Stored layout is transposed (2, 0, 1) to channels-first, then channels 3, 2, 1 are
        # picked for display.
        img = np.load(im_path).transpose(2, 0, 1)
        img = img[[3, 2, 1]].astype(np.float32)

        # Stretch to [0, 1] for display; a constant image has zero range and is shown as zeros.
        rgb_min, rgb_max = img.min(), img.max()
        rgb_range = rgb_max - rgb_min
        if rgb_range > 0:
            img = (img - rgb_min) / rgb_range
        else:
            img = np.zeros_like(img)
        img = img.clip(0, 1)

        if grad_cam_c:
            pass
            # img = images[idx].unsqueeze(0).requires_grad_(True).to(device)
            # cam = GradCAM(model=model_eval, target_layers=target_layer)
            # grayscale_cam = cam(input_tensor=img, targets=None)
            # cam_mask_tensor = torch.tensor(grayscale_cam).unsqueeze(0)
            # ax.imshow(cam_mask_tensor.squeeze(0).squeeze(0).cpu().numpy(), cmap='jet', alpha=0.1)
        else:
            ax.imshow(img.transpose(1, 2, 0))
        ax.set_title(pred, fontsize=20)
        ax.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.2, hspace=0.2)
    plt.show()
    plt.close(fig)  # release the figure; up to `n` large figures are created in this loop