In [17]:
from PIL import Image
import torch
import os
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import pandas as pd
from torchvision import transforms
from torchvision.models import resnet152
from torch import nn

# Point the Hugging Face and torch.hub home/cache directories at the current
# working directory, and configure the ROCm (HIP) runtime.
# NOTE(review): HSA_OVERRIDE_GFX_VERSION='10.3.0' presumably makes the runtime
# treat the GPU as gfx1030, and HIP_VISIBLE_DEVICES='0' restricts HIP to
# device 0. Both are read when the GPU runtime initializes, and `torch` is
# already imported above. They only take effect because no torch.cuda call has
# run yet — consider setting them before `import torch`.
_workdir = os.getcwd()
_env_overrides = {
    'HF_HOME': _workdir,
    'TRANSFORMERS_CACHE': _workdir,
    'TORCH_HOME': _workdir,
    'HSA_OVERRIDE_GFX_VERSION': '10.3.0',
    'HIP_VISIBLE_DEVICES': '0',
}
for _var_name, _var_value in _env_overrides.items():
    os.environ[_var_name] = _var_value

In [2]:
seed = 42
_use_gpu = torch.cuda.is_available()

# Seed every random generator the notebook relies on, in a fixed order:
# the current GPU generator (only when a GPU is present), then torch's
# generator via torch.manual_seed, then NumPy's legacy global generator.
_seeders = ([torch.cuda.manual_seed] if _use_gpu else []) + [torch.manual_seed, np.random.seed]
for _seeder in _seeders:
    _seeder(seed)

# Prefer reproducible cuDNN kernel selection over autotuned kernels.
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = False, True

DEVICE = torch.device('cuda') if _use_gpu else torch.device('cpu')
DEVICE

device(type='cuda')

In [4]:
class WaterbirdDataset(Dataset):
    """Map-style dataset over a Waterbirds metadata table.

    Each item is ``(img_filename, img, label, place)``:
      * ``img_filename`` -- the ``img_filename`` column value (path relative
        to ``root_dir``);
      * ``img`` -- the image at ``root_dir/img_filename`` converted to RGB,
        passed through ``transform`` when one is given;
      * ``label`` / ``place`` -- the ``y`` / ``place`` column values.

    Args:
        df: Metadata frame with at least ``img_filename``, ``y`` and ``place``
            columns. Rows are addressed positionally (``iloc``), so the frame's
            index does not have to be a clean 0..N-1 range.
        root_dir: Directory the ``img_filename`` paths are relative to.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, root_dir, transform=None):
        self.root_dir = root_dir
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Materialize the row once. Each `iloc[idx]` on a mixed-dtype frame
        # builds a fresh Series, and the original did it three times per item.
        row = self.df.iloc[idx]
        img_filename = row['img_filename']

        # The context manager closes the underlying file deterministically.
        # convert() returns a new image that remains valid after the block.
        with Image.open(os.path.join(self.root_dir, img_filename)) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        return img_filename, img, row['y'], row['place']

In [6]:
root_dir = './datasets/waterbird'

# Read the full metadata table and keep only the rows whose `split` is 2
# (the split used as `test_df` below).
# NOTE(review): reset_index() without drop=True keeps the original row number
# as an extra 'index' column in test_df — confirm that column is wanted.
metadata_csv = os.path.join(root_dir, 'metadata.csv')
df = pd.read_csv(metadata_csv)
is_test_row = df['split'].eq(2)
test_df = df.loc[is_test_row].reset_index()

In [8]:
# Network input preprocessing: fixed 224x224 resize, conversion to a float
# tensor, then per-channel normalization.
# NOTE(review): the mean/std values are the commonly used ImageNet statistics.
# Presumably they match the preprocessing the checkpoint was trained with —
# confirm.
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([_resize, _to_tensor, _normalize])

test_set = WaterbirdDataset(df=test_df, root_dir=root_dir, transform=transform)

In [18]:
# Rebuild ResNet-152 with a 2-way classification head so the checkpoint's
# parameter shapes match, move it to DEVICE, and load the saved weights.
# weights_only=True restricts torch.load to tensor data instead of arbitrary
# pickled objects.
checkpoint_path = "/media/atiqur/Extra/Download/Waterbird/Resnet152/800_sample_1/resnet152_augmented_800_sample_1.pth"

model = resnet152(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(DEVICE)
load_result = model.load_state_dict(
    torch.load(checkpoint_path, weights_only=True, map_location=DEVICE)
)

# Inference mode: BatchNorm must use its stored running statistics (and not
# update them) during the prediction passes below. Set it explicitly rather
# than relying on another cell (e.g. a CAM wrapper's constructor) to do it.
model.eval()

load_result

<All keys matched successfully>

In [21]:
# Grad-CAM reads activations/gradients from the last block of `layer4`,
# ResNet's final convolutional stage.
last_conv_block = model.layer4[-1]
target_layers = [last_conv_block]
cam = GradCAM(target_layers=target_layers, model=model)

In [29]:
# For every test image, write a Grad-CAM overlay to `saving_dir` and record
# which source image, label, place and model prediction it belongs to.
metadata = {
    'img_filename': [],
    'cam_filename': [],
    'y': [],
    'place': [],
    'prediction': []
}

saving_dir = '/media/atiqur/Extra/Download/Waterbird/Resnet152/Grad-cam'
# Image.save() does not create missing directories; make sure the output
# folder exists before the long loop starts instead of failing on item 0.
os.makedirs(saving_dir, exist_ok=True)

for i in tqdm(range(len(test_set))):
    # NOTE(review): emptying the allocator cache on every iteration forces a
    # device sync and is usually unnecessary for batch-size-1 inference; keep
    # it only if GPU memory is genuinely tight.
    torch.cuda.empty_cache()
    img_filename, inp, label, place = test_set[i]
    inp = inp.unsqueeze(0).to(DEVICE)  # add a leading batch dimension

    # Plain forward pass to obtain the model's predicted class.
    with torch.no_grad():
        logits = model(inp)
        pred = torch.argmax(logits, dim=1).cpu().item()

    # The heat-map explains the ground-truth class `label`, not `pred`.
    target = [ClassifierOutputTarget(label)]

    # Un-normalized RGB copy of the image, scaled to [0, 1] at 224x224, used
    # only as the overlay background. `with` closes the file handle.
    with Image.open(os.path.join(root_dir, img_filename)) as raw_img:
        rgb_img = raw_img.convert('RGB').resize((224, 224))
    rgb_img = np.float32(rgb_img) / 255

    cam_out = cam(input_tensor=inp, targets=target)
    cam_out = cam_out[0, :]  # the single image in the batch
    visualization = show_cam_on_image(rgb_img, cam_out, use_rgb=True)

    # Prefix the basename with the dataset index so equal basenames from
    # different folders cannot overwrite each other. The outer double quotes
    # keep this f-string valid before Python 3.12 (no same-quote nesting).
    cam_filename = f"{i}_{img_filename.split('/')[-1]}"
    metadata['img_filename'].append(img_filename)
    metadata['cam_filename'].append(cam_filename)
    metadata['y'].append(label)
    metadata['place'].append(place)
    metadata['prediction'].append(pred)

    img = Image.fromarray(visualization)
    img.save(os.path.join(saving_dir, cam_filename))

  0%|          | 0/5794 [00:00<?, ?it/s]

In [30]:
# Persist the per-image bookkeeping (one row per saved overlay) next to the
# overlay images, without the pandas row index.
metadata_df = pd.DataFrame.from_dict(metadata)
metadata_csv_path = os.path.join(saving_dir, 'metadata.csv')
metadata_df.to_csv(metadata_csv_path, index=False)