In [None]:
!git clone https://github.com/MockaWolke/featout_hackathon
%cd featout_hackathon
!bash load_data.sh
!pip install -e .
!pip install captum

Cloning into 'featout_hackathon'...
remote: Enumerating objects: 145, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (40/40), done.[K
remote: Total 145 (delta 20), reused 44 (delta 14), pack-reused 91[K
Receiving objects: 100% (145/145), 32.43 MiB | 18.08 MiB/s, done.
Resolving deltas: 100% (61/61), done.
/content/featout_hackathon
--2023-09-14 16:55:05--  https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
Resolving download.microsoft.com (download.microsoft.com)... 23.6.204.160, 2600:1407:3c00:a86::317f, 2600:1407:3c00:a93::317f
Connecting to download.microsoft.com (download.microsoft.com)|23.6.204.160|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 824887076 (787M) [application/octet-stream]
Saving to: ‘kagglecatsanddogs_5340.zip’


2023-09-14 16:55:11 (136 MB/s) - ‘kagglecatsanddogs_5340.zip’ saved [824887076/824887076]

Obtaining file:///content/featou

In [None]:
import os
import time
import numpy as np

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import csv
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from featout_exp.utils import load_data, load_model, load_data_raw_dataset
from featout_exp.ours import *

In [None]:
# Experiment configuration: everything a reader might want to tune lives here.
BATCHSIZE = 32
NUM_EPOCHS = 10
EXP_NAME = "Ours_Featout"
# When True, the train and test passes stop after three batches (quick smoke test).
TESTING = False
# Passed to `start_featout` as `blur_method`.
BLUR_METHOD = blur_featurs
# algorithm to derive the model's attention; passed to `start_featout` as `algorithm`
ATTENTION_ALGORITHM = simple_gradient_saliency

# Seed torch and NumPy so DataLoader shuffling and any random initialisation are
# repeatable across runs (GPU kernels may still introduce some nondeterminism).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Use the first GPU when available; the model and every batch are moved to this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = load_model(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Wrap the raw training set in Featout; the training loop later calls
# `start_featout` on it to enable feature blurring.
train_dataset = load_data_raw_dataset("train")
train_featout = Featout(train_dataset, "test_images", do_plotting=False, device=device)

# Batch size comes from the config cell so both loaders stay in sync with BATCHSIZE.
# num_workers=0 keeps loading in the main process; presumably needed because Featout
# uses the model while producing samples — TODO confirm before raising it.
train_loader = torch.utils.data.DataLoader(
    train_featout, batch_size=BATCHSIZE, shuffle=True, num_workers=0
)

test_loader = load_data("test", BATCHSIZE, False)


EXP_PATH = os.path.join("logs", EXP_NAME)

os.makedirs(EXP_PATH, exist_ok=True)
# Initialize TensorBoard writer
writer = SummaryWriter(EXP_PATH)

# Prepare a CSV file for logging; the csv module requires newline="" so it controls
# line endings itself. The file is closed by the training cell.
csv_file = open(os.path.join(EXP_PATH, "logs.csv"), "w", newline="")
csv_writer = csv.writer(csv_file)
# Trailing ';' suppresses the display of writerow's return value (characters written).
csv_writer.writerow(["Epoch", "Train Loss", "Train Accuracy", "Test Loss", "Test Accuracy"]);

57

In [None]:
def exploration_rate(epoch):
    """Return the share of the epoch budget reached by ``epoch``, never above 1.

    The value is ``epoch / NUM_EPOCHS`` clipped from above at 1; the training loop
    passes it to ``start_featout`` as ``exp_rate``.
    """
    return min(epoch / NUM_EPOCHS, 1)

In [None]:
def _train_one_epoch(epoch):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    Assumes the caller has already put ``model`` in train mode (and, from the
    second epoch on, called ``start_featout`` on the training dataset).

    Args:
        epoch: zero-based epoch index; only used for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually processed, or
        ``(nan, nan)`` if no batch was processed.
    """
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for step, (inputs, labels) in enumerate(bar):
        # Smoke-test mode: stop after three batches.
        if TESTING and step > 2:
            break

        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight by batch size so the epoch mean is per-sample.
        loss_sum += loss.item() * inputs.size(0)
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen


def _evaluate():
    """Compute mean loss and accuracy of ``model`` on ``test_loader`` without gradients.

    Switches ``model`` to eval mode and leaves it there on return.

    Returns:
        ``(mean_loss, accuracy)``, or ``(nan, nan)`` if no batch was processed.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for step, (inputs, labels) in enumerate(tqdm(test_loader)):
            # Smoke-test mode: stop after three batches.
            if TESTING and step > 2:
                break

            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss_sum += loss.item() * inputs.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen


# Train for NUM_EPOCHS. From the second epoch on, `start_featout` is called on the
# training dataset with an exploration rate from `exploration_rate(epoch)`.
# The CSV file and TensorBoard writer are closed even if the run is interrupted, so
# partial logs survive; re-running this cell requires re-running the setup cell first.
try:
    for epoch in range(NUM_EPOCHS):
        # Train mode is set before start_featout, matching the order of the original run.
        model.train()
        if epoch > 0:
            train_loader.dataset.start_featout(
                model,
                blur_method=BLUR_METHOD,
                algorithm=ATTENTION_ALGORITHM,
                exp_rate=exploration_rate(epoch),
                max_clusters=10,
            )

        train_loss, train_acc = _train_one_epoch(epoch)
        test_loss, test_acc = _evaluate()

        # Log metrics to TensorBoard
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Loss/test", test_loss, epoch)
        writer.add_scalar("Accuracy/test", test_acc, epoch)

        # Log metrics to CSV; flush so the file on disk is complete up to this epoch.
        csv_writer.writerow([epoch + 1, train_loss, train_acc, test_loss, test_acc])
        csv_file.flush()
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {train_loss:.4f} Train Acc: {train_acc:.4f} - Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}")
finally:
    csv_file.close()
    writer.close()


Epoch 1/10: 100%|██████████| 625/625 [01:51<00:00,  5.58it/s]
100%|██████████| 157/157 [00:18<00:00,  8.33it/s]


Epoch 1/10 - Train Loss: 0.4799 Train Acc: 0.7665 - Test Loss: 1.3895 Test Acc: 0.4584


Epoch 2/10:   0%|          | 1/625 [00:00<01:32,  6.73it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   0%|          | 2/625 [00:00<05:48,  1.79it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   0%|          | 3/625 [00:01<06:17,  1.65it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   1%|          | 4/625 [00:02<07:03,  1.47it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   1%|          | 5/625 [00:03<07:29,  1.38it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   1%|          | 6/625 [00:04<07:36,  1.35it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   1%|          | 7/625 [00:04<07:31,  1.37it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   1%|▏         | 8/625 [00:05<07:39,  1.34it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   1%|▏         | 9/625 [00:06<07:47,  1.32it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   2%|▏         | 10/625 [00:07<08:36,  1.19it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   2%|▏         | 11/625 [00:08<08:46,  1.17it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   2%|▏         | 12/625 [00:09<08:51,  1.15it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   2%|▏         | 13/625 [00:09<08:32,  1.19it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   2%|▏         | 14/625 [00:11<09:44,  1.05it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   2%|▏         | 15/625 [00:11<09:08,  1.11it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   3%|▎         | 16/625 [00:12<08:43,  1.16it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   3%|▎         | 17/625 [00:13<08:28,  1.19it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   3%|▎         | 18/625 [00:14<08:27,  1.20it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   3%|▎         | 19/625 [00:15<08:39,  1.17it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   3%|▎         | 20/625 [00:16<09:20,  1.08it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   3%|▎         | 21/625 [00:17<08:48,  1.14it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▎         | 22/625 [00:17<08:34,  1.17it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▎         | 23/625 [00:18<09:11,  1.09it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▍         | 24/625 [00:19<09:36,  1.04it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▍         | 25/625 [00:21<10:01,  1.00s/it]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▍         | 26/625 [00:21<09:28,  1.05it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▍         | 27/625 [00:22<08:54,  1.12it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   4%|▍         | 28/625 [00:23<08:25,  1.18it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   5%|▍         | 29/625 [00:24<08:21,  1.19it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   5%|▍         | 30/625 [00:24<08:00,  1.24it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   5%|▍         | 31/625 [00:25<07:52,  1.26it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   5%|▌         | 32/625 [00:26<07:45,  1.27it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   5%|▌         | 33/625 [00:27<07:51,  1.26it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   5%|▌         | 34/625 [00:28<07:50,  1.26it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   6%|▌         | 35/625 [00:28<07:46,  1.27it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   6%|▌         | 36/625 [00:29<07:54,  1.24it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   6%|▌         | 37/625 [00:30<07:47,  1.26it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   6%|▌         | 38/625 [00:31<07:55,  1.23it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   6%|▌         | 39/625 [00:32<09:27,  1.03it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   6%|▋         | 40/625 [00:33<10:10,  1.04s/it]


 STARTING FEATOUT 
 


Epoch 2/10:   7%|▋         | 41/625 [00:34<09:27,  1.03it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   7%|▋         | 42/625 [00:35<08:59,  1.08it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   7%|▋         | 43/625 [00:36<08:43,  1.11it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   7%|▋         | 44/625 [00:37<08:18,  1.17it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   7%|▋         | 45/625 [00:38<08:47,  1.10it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   7%|▋         | 46/625 [00:38<08:22,  1.15it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 47/625 [00:39<08:02,  1.20it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 48/625 [00:40<07:53,  1.22it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 49/625 [00:41<07:55,  1.21it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 50/625 [00:42<07:40,  1.25it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 51/625 [00:42<07:38,  1.25it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 52/625 [00:43<07:39,  1.25it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   8%|▊         | 53/625 [00:44<08:01,  1.19it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   9%|▊         | 54/625 [00:45<08:43,  1.09it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   9%|▉         | 55/625 [00:46<08:59,  1.06it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   9%|▉         | 56/625 [00:47<08:33,  1.11it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   9%|▉         | 57/625 [00:48<08:07,  1.17it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   9%|▉         | 58/625 [00:49<07:57,  1.19it/s]


 STARTING FEATOUT 
 


Epoch 2/10:   9%|▉         | 59/625 [00:49<07:46,  1.21it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  10%|▉         | 60/625 [00:50<07:29,  1.26it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  10%|▉         | 61/625 [00:51<07:27,  1.26it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  10%|▉         | 62/625 [00:51<07:07,  1.32it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  10%|█         | 63/625 [00:52<07:03,  1.33it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  10%|█         | 64/625 [00:53<07:01,  1.33it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  10%|█         | 65/625 [00:54<06:56,  1.34it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  11%|█         | 66/625 [00:54<06:59,  1.33it/s]


 STARTING FEATOUT 
 


Epoch 2/10:  11%|█         | 67/625 [00:55<06:54,  1.35it/s]


 STARTING FEATOUT 
 


In [None]:
# Persist the trained weights (state_dict only) alongside this experiment's logs.
checkpoint_path = os.path.join(EXP_PATH, "model.pth")
torch.save(model.state_dict(), checkpoint_path)

In [None]:
!zip -r Hers.zip logs/Standard_Featout


In [None]:
# Colab-only: trigger a browser download of the zipped experiment logs.
from google.colab import files

archive_name = 'Hers.zip'
files.download(archive_name)