In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import shap
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Path to the dataset.
# The original local location remains the default; set the EQML_DATA_PATH
# environment variable to run the notebook against a copy stored elsewhere.
data_path = os.environ.get(
    "EQML_DATA_PATH",
    "E:\\EQML_Project\\data_preprocessed\\log_spectrograms\\train",
)

# Define custom dataset
class EarthquakeDataset(Dataset):
    """Dataset of ``.png`` images from one directory, labelled by filename.

    A file whose name contains ``"_post.png"`` gets label 1 and a file whose
    name contains ``"_pre.png"`` gets label 0. Any other ``.png`` file in the
    directory is still listed, and indexing it raises ``ValueError``.

    Args:
        data_path: Directory that holds the ``.png`` files.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping identical across runs and machines.
        self.file_list = sorted(
            f for f in os.listdir(data_path) if f.endswith(".png")
        )

    def __len__(self):
        """Return the number of ``.png`` files found in ``data_path``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        Raises:
            ValueError: If the filename contains neither ``"_post.png"`` nor
                ``"_pre.png"``.
        """
        img_name = self.file_list[idx]

        # Extract label from filename first, so a badly named file fails
        # before any disk I/O and the error names the offending file.
        if "_post.png" in img_name:
            label = 1  # Aftershock
        elif "_pre.png" in img_name:
            label = 0  # Mainshock
        else:
            raise ValueError(
                f"Filename does not match expected pattern: {img_name!r}"
            )

        img_path = os.path.join(self.data_path, img_name)
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Preprocessing pipeline handed to the dataset: a single ToTensor() step.
transform = transforms.Compose([transforms.ToTensor()])

In [5]:
# Create dataset and dataloader
dataset = EarthquakeDataset(data_path, transform=transform)
# On Windows, DataLoader workers are separate spawned processes that must
# re-import the dataset class. A class defined in a notebook cell cannot be
# re-imported there, so loading stays in the main process on Windows and
# uses 4 workers elsewhere.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0 if os.name == "nt" else 4,
)

In [6]:
# Define the CNN model
class CNN2D(nn.Module):
    """Seven-stage 2D CNN classifier for 3-channel images.

    Every stage is Conv2d -> BatchNorm2d -> MaxPool2d(2, 2) -> ReLU. The
    second convolution additionally uses stride 2. The stages are followed by
    flatten, Dropout(0.3) and a single linear layer that outputs raw logits.

    Args:
        num_classes: Number of output logits.
        flat_features: Size of the flattened feature vector fed to the linear
            layer. The default, ``256 * 4 * 2``, corresponds to a final
            256-channel feature map of 4 x 2 (for example a 1024 x 512
            input). Pass a different value for other input sizes.
    """

    def __init__(self, num_classes, flat_features=256 * 4 * 2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.BatchNorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.BatchNorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.BatchNorm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.BatchNorm4 = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.BatchNorm5 = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.BatchNorm6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.BatchNorm7 = nn.BatchNorm2d(256)
        # One pooling and one activation module per stage. Hook-based
        # explainers such as SHAP's PyTorch DeepExplainer save per-call
        # tensors on the module object, so an instance that is called several
        # times in forward() would have them overwritten. These modules hold
        # no parameters, so the state_dict keys are unaffected.
        self.pools = nn.ModuleList(
            nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(7)
        )
        self.activations = nn.ModuleList(nn.ReLU() for _ in range(7))
        self.dropout = nn.Dropout(0.3)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        stages = (
            (self.conv1, self.BatchNorm1),
            (self.conv2, self.BatchNorm2),
            (self.conv3, self.BatchNorm3),
            (self.conv4, self.BatchNorm4),
            (self.conv5, self.BatchNorm5),
            (self.conv6, self.BatchNorm6),
            (self.conv7, self.BatchNorm7),
        )
        for (conv, norm), pool, act in zip(stages, self.pools, self.activations):
            x = act(pool(norm(conv(x))))
        x = self.dropout(self.flatten(x))
        x = self.fc(x)
        return x

In [7]:
# Choose the compute device (CUDA when PyTorch reports one, otherwise CPU),
# then build the network, the loss function and the optimizer.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = CNN2D(num_classes=2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [3]:
# Diagnostic: show whether PyTorch reports an available CUDA device.
torch.cuda.is_available()

False

In [8]:
# Training loop
def train_model(model, dataloader, criterion, optimizer, num_epochs=10):
    """Train ``model`` in place and print the mean loss after every epoch.

    Args:
        model: Network to optimise; its parameters decide which device the
            batches are moved to.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of full passes over ``dataloader``.
    """
    # Follow the model's own device instead of a notebook-level global, so the
    # function works no matter which cells were run before it.
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    for epoch in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for inputs, labels in tqdm(dataloader, desc = "dataloaders"):
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size (assumes the criterion
            # averages over the batch) so a short last batch counts fairly.
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Divide by the samples actually processed: this stays correct with
        # drop_last or a subset sampler and cannot divide by zero.
        epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

In [9]:
# Run the training loop for 10 epochs
train_model(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)

  0%|          | 0/10 [00:00<?, ?it/s]

: 

: 

In [None]:
# SHAP explainability
model.eval()

# Select a batch of images to explain
images, _ = next(iter(dataloader))
images = images.to(device)

# Hold out a few images to explain and use the rest of the batch as the
# background (reference) set, so the explained samples are not also their own
# baseline. A batch too small to split falls back to using it for both roles.
N_EXPLAIN = 4
if len(images) > N_EXPLAIN:
    background, to_explain = images[:-N_EXPLAIN], images[-N_EXPLAIN:]
else:
    background = to_explain = images


def predict(images):
    """Return softmax class probabilities for ``images`` as a NumPy array.

    Convenience helper only: the DeepExplainer below is given the model
    itself rather than this function.
    """
    with torch.no_grad():
        logits = model(images)
        probs = nn.Softmax(dim=1)(logits)
    return probs.cpu().numpy()


# Define a SHAP explainer
explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(to_explain)

# shap.image_plot expects channels-last arrays (N, H, W, C), while the tensors
# here are channels-first (N, C, H, W). Depending on the SHAP version,
# shap_values is either a list with one (N, C, H, W) array per class or a
# single array with the class axis last - NOTE(review): confirm against the
# installed SHAP release.
if isinstance(shap_values, list):
    shap_per_class = [np.transpose(sv, (0, 2, 3, 1)) for sv in shap_values]
else:
    shap_per_class = [
        np.transpose(shap_values[..., k], (0, 2, 3, 1))
        for k in range(shap_values.shape[-1])
    ]
pixels = np.transpose(to_explain.cpu().numpy(), (0, 2, 3, 1))

# Visualize SHAP explanations
shap.image_plot(shap_per_class, pixels)
