In [18]:
#%%
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from seg_dataset import SegmentationDataset, RandomFlipRotate
import segmentation_models_pytorch as smp 
import seaborn as sns
import matplotlib.pyplot as plt
from segformer_pytorch import Segformer
import torch.nn.functional as F
import os 
import glob
from torchmetrics.classification.jaccard import MulticlassJaccardIndex as jaccard

from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from captum.attr import LayerGradCam, LayerAttribution, visualization as viz
from captum.attr import IntegratedGradients, Occlusion

import os
from skimage import exposure
import time

In [3]:
wget https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M.pt?download=true -O Prithvi_100M.pt
wget https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M_config.yaml?download=true -O Prithvi_100M_config.yaml

SyntaxError: invalid decimal literal (2876559065.py, line 1)

In [11]:
! wget https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M.pt

zsh:1: command not found: wget


In [15]:
cfg_path = '/Users/sebrah13/Desktop/python_codes/prithvi-pytorch/tests/Prithvi_100M_config.yaml'
ckpt_path = "/Users/sebrah13/Desktop/python_codes/RSclass_LCLUC/weights/Prithvi_100M.pt"

In [20]:
from prithvi_pytorch import PrithviViT
model = PrithviViT(
    ckpt_path=ckpt_path,  # path to pretrained checkpoint Prithvi_100M.pt
    cfg_path=cfg_path,  # path to pretrained config Prithvi_100M_config.yaml
    num_classes=12,  # num classifier classes
    in_chans=8,  # right now only supports the pretrained 6 channels
    img_size=64,  # supports other image sizes than 224
    freeze_encoder=True  # freeze the pretrained prithvi if you just want to linear probe
)

In [21]:
X = torch.randn(1, 8, 64, 64)
out = model(X)
print(out.shape)

torch.Size([1, 12])


In [25]:
from prithvi_pytorch import PrithviEncoderDecoder

model = PrithviEncoderDecoder(
    ckpt_path=ckpt_path,  # path to pretrained checkpoint Prithvi_100M.pt
    cfg_path=cfg_path,  # path to pretrained config Prithvi_100M_config.yaml
    num_classes=9,  # num classifier classes
    in_chans=8,  # right now only supports the pretrained 6 channels
    img_size=64,  # supports other image sizes than 224
    freeze_encoder=True  # freeze the pretrained prithvi
)

In [26]:
X = torch.randn(1, 8, 64, 64)
out = model(X)
print(out.shape)

torch.Size([1, 10, 64, 64])


In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# %%
EPOCHS = 2
BS = 128

In [None]:
# #%% Instantiate Dataset and Dataloader
train_ds = SegmentationDataset(data_path='/scratch/sebrah13/RS_class/yearlyImage/Train')
# sampler = torch.utils.data.WeightedRandomSampler(train_ds.weights, len(train_ds.weights))
train_dataloader = DataLoader(train_ds, batch_size=BS, pin_memory=True)
val_ds = SegmentationDataset(data_path='/scratch/sebrah13/RS_class/yearlyImage/Val')
# sampler1 = torch.utils.data.WeightedRandomSampler(val_ds.weights, len(val_ds.weights))
val_dataloader = DataLoader(val_ds, batch_size=BS, pin_memory=True)

In [None]:
for DD in val_dataloader:
    print(DD['image'].shape, DD['mask'].shape)
    
    break

In [None]:
model.to(DEVICE)
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.001),
])



wights = glob.glob("SegFormer*.pth")
if wights:
    model.load_state_dict(torch.load(f'SegFormer_epochs_{EPOCHS}_crossentropy_state_dict.pth'))
    print("Pretrained weights loaded")
else:
    print("No pretrained weights found, intializing random weights...")    
    # %%
criterion = nn.CrossEntropyLoss().to(DEVICE)



IoU = jaccard( num_classes= 9, average='none').to(DEVICE)

train_losses, val_losses = [],[]

In [None]:
Prev_loss = 10000
min_loss = 10000
for e in range(EPOCHS):
    model.train()
    running_train_loss, running_val_loss = 0, 0
    metrics = {'iou_scores': [], 'f1_scores': [], 'f2_scores': [], 'accuracies': [], 'recalls': [], 'ious': [], 'losses': []}
    for i, data in enumerate(train_dataloader):
        #training phase
        image_i, mask_i = data['image'], data['mask']
        image = image_i.to(DEVICE)
        mask = mask_i.to(DEVICE)
        
        # reset gradients
        optimizer.zero_grad() 
        #forward
        output = model(image.float())
        # Upsample the output to match the target label size
        # output_upsampled = output
        output_upsampled =F.interpolate(output, size=mask.shape[1:], mode = 'bilinear',   align_corners=False)
        # calc losses
        train_loss = criterion(output_upsampled .float(), mask.long())

        # back propagation
        train_loss.backward()
        optimizer.step() #update weight          
        running_train_loss += train_loss.item()
        
        # adding metrics
        IoU_metric = IoU(output_upsampled .float(), mask.long())
        _, pred = torch.max(output_upsampled, 1)
        tp, fp, fn, tn = smp.metrics.get_stats(pred, mask.long(), mode='multiclass', num_classes=9)
        iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
        f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
        accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
        recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
        #storing the metrics in a dictionary
        metrics['ious'].append(IoU_metric)
        metrics['iou_scores'].append(iou_score)
        metrics['f1_scores'].append(f1_score)
        metrics['f2_scores'].append(f2_score)
        metrics['accuracies'].append(accuracy)
        metrics['recalls'].append(recall)
        metrics['losses'].append(train_loss.item())
        
    train_losses.append(running_train_loss) 
    
    # Compute mean of each metric
    mean_metrics =  {
    metric: np.mean([v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in values])
    for metric, values in metrics.items()
}
    print(f"Epoch: {e}, Training Mean Loss: {mean_metrics['losses']}, Mean IoU: {mean_metrics['ious']}, "
        f"Mean IoU Score: {mean_metrics['iou_scores']}, Mean F1 Score: {mean_metrics['f1_scores']}, ")
        # f"Mean F2 Score: {mean_metrics['f2_scores']}, Mean Accuracy: {mean_metrics['accuracies']}, "
        # f"Mean Recall: {mean_metrics['recalls']}")
    
    # validation
    model.eval()
    val_metrics = {'iou_scores': [], 'f1_scores': [], 'f2_scores': [], 'accuracies': [], 'recalls': [], 'ious': []}
    with torch.no_grad():
        for i, data in enumerate(val_dataloader):
            image_i, mask_i = data['image'], data['mask']
            image = image_i.to(DEVICE)
            mask = mask_i.to(DEVICE)
            #forward
            output = model(image.float())
            # output_upsampled = output
            
            output_upsampled = F.interpolate(output, size=mask.shape[1:],mode = 'bilinear', align_corners=False)
            # calc losses
            val_loss = criterion(output_upsampled.float(), mask.long())
            running_val_loss += val_loss.item()
            
            # Calculate additional metrics
            _, pred = torch.max(output_upsampled, 1)
            IoU_metric = IoU(output_upsampled.float(), mask.long())
            tp, fp, fn, tn = smp.metrics.get_stats(pred, mask.long(), mode='multiclass', num_classes=9)
            iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
            f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
            f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
            accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
            recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
            # Store metrics in the validation metrics dictionary
            val_metrics['ious'].append(IoU_metric)
            val_metrics['iou_scores'].append(iou_score)
            val_metrics['f1_scores'].append(f1_score)
            val_metrics['f2_scores'].append(f2_score)
            val_metrics['accuracies'].append(accuracy)
            val_metrics['recalls'].append(recall)
            
    val_losses.append(running_val_loss)
    # Compute mean of each metric and loss
    mean_val_metrics = {
    metric: np.mean([v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in values])
    for metric, values in val_metrics.items()
}
    mean_val_loss = np.mean(running_val_loss / len(val_dataloader))
    # Append the average validation loss for this epoch
    # val_losses.append(mean_val_loss)

# Log or print validation metrics and loss
    print(f"Validation Loss: {mean_val_loss}, Mean IoU: {mean_val_metrics['ious']}, "
      f"Mean IoU Score: {mean_val_metrics['iou_scores']}, Mean F1 Score: {mean_val_metrics['f1_scores']}, ")
    #   f"Mean F2 Score: {mean_val_metrics['f2_scores']}, Mean Accuracy: {mean_val_metrics['accuracies']}, "
    #   f"Mean Recall: {mean_val_metrics['recalls']}")
    
    
    
    if np.median(running_val_loss) < min_loss:
        print(f"Loss value improved from {min_loss} to {np.median(running_val_loss)}; Saving model weights...")
        torch.save(model.state_dict(), f'Prithvi_epochs_{EPOCHS}_crossentropy_state_dict.pth')
        Prev_loss = np.median(running_val_loss)
        if min_loss > Prev_loss:
            min_loss = Prev_loss
            
                # Write report to text file
        with open('report.txt', 'a') as file:  # 'a' mode for appending in case this happens multiple times
            file.write(f"Epoch: {e}, Median Validation Loss: {running_train_loss},\n")
            file.write(f"Epoch: {e}, Median Validation Loss: {running_val_loss},\n")
            file.write("Mean Validation Metrics:\n")
            for metric, value in mean_val_metrics.items():
                file.write(f"{metric}: {value}\n")
            file.write("Metrics training Criteria (if any):\n")
            for metric, value in mean_metrics.items():
                file.write(f"{metric}: {value}\n")
           
    print(f"Epoch: {e}: Train Cumulative Loss: {np.median(running_train_loss)}, Val cumulative Loss: {np.median(running_val_loss)} ")


In [None]:
#%% TRAIN LOSS
plt.figure(figsize=(10, 5))  
sns.lineplot(x = range(len(train_losses)), y= train_losses)
sns.lineplot(x = range(len(train_losses)), y= val_losses)

# Adding titles and labels
plt.title('Training vs Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()  # Show legend to identify the lines
plt.savefig('trainloss.png')
plt.show()  # Display the plotplt.show()

In [None]:
test_dir = '/scratch/sebrah13/RS_class/yearlyImage/Val'
test_ds = SegmentationDataset(data_path=test_dir)
test_dataloader = DataLoader(test_ds, batch_size=1, shuffle=True)

model_path = '/scratch/sebrah13/RS_class/LinkNET/LinkNet_epochs_200_crossentropy_state_dict.pth'


In [None]:
model.eval()
model.to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=torch.device(DEVICE)))

In [None]:
# %% Helper functions to calculate metrics
def calculate_metrics(tp, fp, fn, tn):
    epsilon = 1e-7
    f1 = 2 * tp / (2 * tp + fp + fn + epsilon)
    f2 = 5 * tp / (5 * tp + 4 * fn + fp + epsilon)
    accuracy = (tp + tn) / (tp + tn + fp + fn + epsilon)
    recall = tp / (tp + fn + epsilon)
    return f1, f2, accuracy, recall

# %% Model Evaluation
num_classes = 9
tp = np.zeros(num_classes)
fp = np.zeros(num_classes)
fn = np.zeros(num_classes)
tn = np.zeros(num_classes)  # Added true negatives

all_true_labels, all_pred_labels = [],[]

with torch.no_grad():
    for data in test_dataloader:
        inputs, outputs = data['image'], data['mask']
        true = outputs.to(torch.float32).to(DEVICE)
        pred = model(inputs.to(DEVICE).float())
        pred = F.interpolate(pred, size=true.shape[1:], mode = 'bilinear',   align_corners=False)
        _, predicted = torch.max(pred, 1)
        
        all_true_labels.extend(true.cpu().numpy().flatten())
        all_pred_labels.extend(predicted.cpu().numpy().flatten())

        for cls in range(num_classes):
            tp[cls] += torch.sum((predicted == cls) & (true == cls)).item()
            fp[cls] += torch.sum((predicted == cls) & (true != cls)).item()
            fn[cls] += torch.sum((predicted != cls) & (true == cls)).item()
            tn[cls] += torch.sum((predicted != cls) & (true != cls)).item()  # Correctly count true negatives

# Compute IoU for each class
class_iou = tp / (tp + fp + fn + 1e-7)
mean_iou = np.mean(class_iou)

# Compute additional metrics
f1_scores, f2_scores, accuracies, recalls = calculate_metrics(tp, fp, fn, tn)
mean_f1 = np.mean(f1_scores)
mean_f2 = np.mean(f2_scores)
mean_accuracy = np.mean(accuracies)
mean_recall = np.mean(recalls)

print(f"Class-wise IoUs: {class_iou}")
print(f"Mean IoU: {mean_iou}")
print(f"Mean F1 Score: {mean_f1}")
print(f"Mean F2 Score: {mean_f2}")
print(f"Mean Accuracy: {mean_accuracy}")
print(f"Mean Recall: {mean_recall}")


In [None]:
#%% Confusion Matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import pandas as pd


Name = 'Linknet'

# Compute confusion matrix
conf_matrix = confusion_matrix(all_true_labels, all_pred_labels, labels=list(range(num_classes)))

# Normalize confusion matrix by row (true labels)
conf_matrix_normalized = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis]

# Convert to percentage
conf_matrix_normalized *= 100
# Create a custom colormap
cmap = sns.color_palette("rocket_r", as_cmap=True)
# Plot normalized confusion matrix
plt.figure(figsize=(6, 4.5))
ax = sns.heatmap(pd.DataFrame(conf_matrix_normalized, columns=[f'{i+1}' for i in range(num_classes)], 
                         index=[f'{i+1}' for i in range(num_classes)]), 
            annot=True, fmt='.2f', cmap=cmap, vmin=0, vmax=100)

# Set font properties
plt.xlabel('Predicted', fontsize=14, fontname='Times New Roman')
plt.ylabel('True', fontsize=14, fontname='Times New Roman')
plt.title(f'{Name}'
        #   {config_name[-5:]}'
          , fontsize=18, fontname='Times New Roman')

# Set ticks font properties
ax.set_xticklabels(ax.get_xticklabels(), fontsize=14, fontname='Times New Roman')
ax.set_yticklabels(ax.get_yticklabels(), fontsize=14, fontname='Times New Roman')

# Adjust color bar font properties
colorbar = ax.collections[0].colorbar
colorbar.ax.tick_params(labelsize=14)
colorbar.ax.set_yticklabels([f'{int(i)}%' for i in colorbar.get_ticks()], fontsize=14, fontname='Times New Roman')

plt.savefig(f'normalized_confusion_matrix{Name}.png')
plt.show()

In [None]:
#%% Pick a test image and show it
from matplotlib.colors import ListedColormap
Sample = next(iter(test_dataloader))
image_test, mask = Sample['image'], Sample['mask']
plt.imshow(np.transpose(image_test[0, 0:3, :, :].cpu().numpy(), (1, 2, 0)))

#%% EVALUATE MODEL
# create preds
with torch.no_grad():
    image_test = image_test.float().to(DEVICE)
    output = model(image_test)

#%%
output_cpu = output.cpu().squeeze().numpy()
Output = output_cpu[:,:,:]
output_cpu = Output.transpose((1, 2, 0))
output_cpu = output_cpu.argmax(axis=2)

# %%



# Define a color map with 9 distinct colors for values 0 to 8
colors = ['black', 'red', 'blue', 'green', 'purple', 'orange', 'yellow', 'cyan', 'magenta']
cmap = ListedColormap(colors)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
fig.suptitle('True and Predicted Mask')

true_mask_img = axs[0].imshow(mask[0, :, :], cmap=cmap, vmin=0, vmax=8)
predicted_mask_img = axs[1].imshow(output_cpu, cmap=cmap, vmin=0, vmax=8)

# Add titles
axs[0].set_title("True Mask")
axs[1].set_title("Predicted Mask")

# Add color bar to interpret the values
fig.colorbar(true_mask_img, ax=axs, orientation='horizontal', fraction=0.05, pad=0.1, label='Class Values')

# Save and display the plot
plt.savefig('Predicted_Mask.png')
plt.show()

# %%
print(np.unique(output_cpu))
print(np.unique(mask[0, :, :].numpy()))