## Use Case: Spurious Correlations in EuroSAT

In [3]:
import torch
from torchsat.models import resnet50
import torchsat.transforms.functional as F
import numpy as np
import pandas as pd

import argparse
import os

import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchsat.transforms.transforms_cls as T_cls
from torchsat.datasets.folder import ImageFolder
from torchsat.models.utils import get_model

from evaluate import load

from collections import Counter
import re
from sentence_transformers import SentenceTransformer

os.chdir("sandbox-DnD")
import clip

### Train ResNet50 model on EuroSAT dataset

In [4]:
# NOTE: For replication purposes, our used .pth file is provided in the GitHub, so DO NOT run this cell to replicate our results

'''
!mkdir output
!python scripts/train_cls.py \
         --train-path EuroSAT/train \
         --val-path EuroSAT/val/ \
         --model resnet50 \
         --num-classes 10 \
         --device cuda \
         -b 64 \
         --print-freq 20 \
         --ckp-di output
'''

In [5]:
class_to_idx = {
        'AnnualCrop': 0,
        'Forest': 1,
        'HerbaceousVegetation': 2,
        'Highway': 3,
        'Industrial': 4,
        'Pasture': 5,
        'PermanentCrop': 6,
        'Residential': 7,
        'River': 8,
        'SeaLake': 9,
    }
idx_to_class = {v: k for k, v in class_to_idx.items()}

# load trained model
model_sat = resnet50(num_classes=10)
ckp = 'output/cls_epoch_35.pth'
model_sat.load_state_dict(torch.load(ckp, map_location=torch.device('cuda')))

<All keys matched successfully>

In [6]:
def evaluate(epoch, model, criterion, data_loader, device):
    model.eval()
    model.to(device)
    loss = 0
    correct = 0
    class_correct = [0 for _ in range(10)]
    class_count = [0 for _ in range(10)]
    with torch.no_grad():
        for idx, (image, target) in enumerate(data_loader):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss += criterion(output, target).item()

            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            for i in range(len(target)):
                label = target[i].item()
                prediction = pred[i].item()
                if prediction == label:
                    class_correct[label] += 1
                class_count[label] += 1

        loss /= len(data_loader.dataset)/data_loader.batch_size

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            loss, correct, len(data_loader.dataset),
            100. * correct / len(data_loader.dataset)))
        
        class_names = [
            'AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway',
            'Industrial', 'Pasture', 'PermanentCrop', 'Residential',
            'River', 'SeaLake'
        ]

        for i in range(10):
            if class_count[i] > 0:
                accuracy = 100. * class_correct[i] / class_count[i]
                print('\n{}: Accuracy: {}/{} ({:.0f}%)\n'.format(
                    class_names[i], class_correct[i], class_count[i], accuracy))
            else:
                print('\n{}: No samples found\n'.format(class_names[i]))


val_transform = T_cls.Compose([
        T_cls.ToTensor(),
        T_cls.Normalize(),
    ])
dataset_val = ImageFolder("EuroSAT/val/", val_transform)

# Evaluate on original model

evaluate(35, model_sat, nn.CrossEntropyLoss(), DataLoader(dataset_val, batch_size=16, shuffle=True), torch.device('cuda'))


Test set: Average loss: 0.1552, Accuracy: 2572/2700 (95%)


AnnualCrop: Accuracy: 285/300 (95%)


Forest: Accuracy: 294/300 (98%)


HerbaceousVegetation: Accuracy: 281/300 (94%)


Highway: Accuracy: 240/250 (96%)


Industrial: Accuracy: 242/250 (97%)


Pasture: Accuracy: 183/200 (92%)


PermanentCrop: Accuracy: 227/250 (91%)


Residential: Accuracy: 296/300 (99%)


River: Accuracy: 232/250 (93%)


SeaLake: Accuracy: 292/300 (97%)



### Load Data and Models

In [7]:
# Define embedding models

mpnetmodel = SentenceTransformer('all-mpnet-base-v2')
clip_model, _ = clip.load('ViT-B/16', device='cuda')
bertscore = load("bertscore")

In [8]:
# Load DnD labels

csv = pd.read_csv("data/DnD_results/eurosat_results/layer4.csv")
labels = list(csv["Label 1"])

### Investigate most common concept in the model

In [9]:
def most_common_concept(strings):
    # Tokenize and normalize
    tokens = []
    for str in strings:
        words = re.findall(r'\w+', str.lower())
        tokens.extend(words)
    
    # Count the frequency of each token
    token_counts = Counter(tokens)
    
    # Return the most common token
    return token_counts.most_common(1)[0]

common_concept = most_common_concept(labels)
print(f"The most common concept is '{common_concept[0]}' with {common_concept[1]} occurrences.")

The most common concept is 'fishing' with 883 occurrences.


### Define Pruning Function

In [10]:
def prune_neurons(ids_to_prune):
    layer = "layer4"
    num_neurons = eval("model_sat.{}[-1].conv3.out_channels".format(layer))
    
    for block in eval("model_sat.{}".format(layer)):
        for comp in block.children():
            if (hasattr(comp, "out_channels") and comp.out_channels == num_neurons) or (hasattr(comp, "num_features") and comp.num_features == num_neurons):
                if comp.bias != None:
                    mask_bias = torch.ones_like(comp.bias)
                    for id in ids_to_prune:
                        mask_bias[id] = 0
                    comp = torch.nn.utils.prune.custom_from_mask(comp, "bias", mask_bias)
                mask_weight = torch.ones_like(comp.weight)
                for id in ids_to_prune:
                    mask_weight[id] = 0
                comp = torch.nn.utils.prune.custom_from_mask(comp, "weight", mask_weight)
            if hasattr(comp, "__iter__"):
                for ds_comp in comp:
                    if ds_comp.bias != None:
                        mask_bias = torch.ones_like(ds_comp.bias)
                        for id in ids_to_prune:
                            mask_bias[id] = 0
                        ds_comp = torch.nn.utils.prune.custom_from_mask(ds_comp, "bias", mask_bias)
                    mask_weight = torch.ones_like(ds_comp.weight)
                    for id in ids_to_prune:
                        mask_weight[id] = 0
                    ds_comp = torch.nn.utils.prune.custom_from_mask(ds_comp, "weight", mask_weight)

### Prune "fishing" neurons

In [11]:
words = ["fishing"]

ids_to_prune = []
for idx in range(len(labels)):
    prune = False
    for word in words:
        if word in labels[idx].lower():
            prune = True
    if prune:
        ids_to_prune.append(idx)

prune_neurons(ids_to_prune)

evaluate(35, model_sat, nn.CrossEntropyLoss(), DataLoader(dataset_val, batch_size=16, shuffle=True), torch.device('cuda'))


Test set: Average loss: 0.1563, Accuracy: 2572/2700 (95%)


AnnualCrop: Accuracy: 285/300 (95%)


Forest: Accuracy: 294/300 (98%)


HerbaceousVegetation: Accuracy: 281/300 (94%)


Highway: Accuracy: 240/250 (96%)


Industrial: Accuracy: 242/250 (97%)


Pasture: Accuracy: 183/200 (92%)


PermanentCrop: Accuracy: 227/250 (91%)


Residential: Accuracy: 296/300 (99%)


River: Accuracy: 232/250 (93%)


SeaLake: Accuracy: 292/300 (97%)



### Prune "purple" and "pink" neurons

In [12]:
# Reload model
model_sat = resnet50(num_classes=10)
ckp = 'output/cls_epoch_35.pth'
model_sat.load_state_dict(torch.load(ckp, map_location=torch.device('cuda')))

words = ["purple", "pink"]

ids_to_prune = []
for idx in range(len(labels)):
    prune = False
    for word in words:
        if word in labels[idx].lower():
            prune = True
    if prune:
        ids_to_prune.append(idx)

prune_neurons(ids_to_prune)

evaluate(35, model_sat, nn.CrossEntropyLoss(), DataLoader(dataset_val, batch_size=16, shuffle=True), torch.device('cuda'))


Test set: Average loss: 9.5645, Accuracy: 657/2700 (24%)


AnnualCrop: Accuracy: 265/300 (88%)


Forest: Accuracy: 0/300 (0%)


HerbaceousVegetation: Accuracy: 0/300 (0%)


Highway: Accuracy: 155/250 (62%)


Industrial: Accuracy: 0/250 (0%)


Pasture: Accuracy: 0/200 (0%)


PermanentCrop: Accuracy: 47/250 (19%)


Residential: Accuracy: 0/300 (0%)


River: Accuracy: 190/250 (76%)


SeaLake: Accuracy: 0/300 (0%)

