In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import shap
import numpy as np
from tqdm import tqdm
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Path to the dataset
# The spectrogram root can be overridden through the EQML_DATA_ROOT environment
# variable so the notebook is not tied to one machine's drive layout. When the
# variable is unset, the original location on the E: drive is used.
DATA_ROOT = os.environ.get(
    "EQML_DATA_ROOT",
    "E:\\EQML_Project\\data_preprocessed\\log_spectrograms",
)
train_data_path = os.path.join(DATA_ROOT, "train")
valid_data_path = os.path.join(DATA_ROOT, "valid")

In [3]:
# Define custom dataset
class EarthquakeDataset(Dataset):
    """Dataset of PNG images whose class label is encoded in the filename.

    Every ``*.png`` file directly inside ``data_path`` is one sample. A name
    ending in ``_post.png`` is labelled 1 and a name ending in ``_pre.png`` is
    labelled 0.

    Args:
        data_path: Directory that contains the PNG files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.file_list = sorted(
            f for f in os.listdir(data_path) if f.endswith(".png")
        )

    def __len__(self):
        """Return the number of PNG files found in ``data_path``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        Raises:
            ValueError: If the filename ends in neither ``_post.png`` nor
                ``_pre.png``.
        """
        img_name = self.file_list[idx]

        # Extract label from filename. The suffix is checked with endswith so
        # a marker appearing in the middle of a name cannot decide the label,
        # and it is checked before the image is decoded so bad names fail fast.
        if img_name.endswith("_post.png"):
            label = 1  # Aftershock
        elif img_name.endswith("_pre.png"):
            label = 0  # Mainshock
        else:
            raise ValueError("Filename does not match expected pattern")

        img_path = os.path.join(self.data_path, img_name)
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read, and the context manager then closes the file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Preprocessing pipeline handed to the datasets; its single step is ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
# Build the training / validation datasets and the loaders that batch them.
dataset = EarthquakeDataset(train_data_path, transform=transform)
valid_dataset = EarthquakeDataset(valid_data_path, transform=transform)

# On Windows, DataLoader workers are started with "spawn" and must unpickle the
# dataset. A Dataset class defined inside the notebook (__main__) cannot be
# re-imported by those workers, so iterating the loader fails or hangs.
# Load in the main process there and keep 4 workers on other platforms.
NUM_WORKERS = 0 if os.name == "nt" else 4
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=NUM_WORKERS)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=NUM_WORKERS)


In [5]:
class CNN2D(nn.Module):
    """Three-stage convolutional classifier for 3-channel images.

    Layout: conv(3->32) -> BN -> maxpool -> ReLU, conv(32->64, stride 2) -> BN
    -> maxpool -> ReLU, conv(64->128) -> BN -> ReLU, then two fully connected
    layers producing ``num_classes`` logits.

    The feature map is adaptively average-pooled to 4x2 before the first
    linear layer. For inputs whose feature map is already 4x2 (the size the
    original ``128 * 4 * 2`` constant assumed) this is an identity, so
    existing checkpoints behave the same; other input sizes no longer fail
    with a shape mismatch at ``fc1``.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.BatchNorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.BatchNorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.BatchNorm3 = nn.BatchNorm2d(128)

        # NOTE(review): each call site gets its own ReLU / MaxPool instance.
        # These modules have no parameters, so state_dict keys are unchanged.
        # Hook-based attribution tools (this notebook later uses
        # shap.DeepExplainer) keep per-module state, and reusing one instance
        # at several points of forward() is a known source of errors there --
        # confirm against the installed SHAP version.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu_fc = nn.ReLU()

        # Pins the spatial size seen by fc1 regardless of the input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 2))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 4 * 2, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of 3-channel images ``x``."""
        x = self.relu(self.maxpool(self.BatchNorm1(self.conv1(x))))
        x = self.relu2(self.maxpool2(self.BatchNorm2(self.conv2(x))))
        x = self.relu3(self.BatchNorm3(self.conv3(x)))
        x = self.adaptive_pool(x)
        x = self.flatten(x)
        x = self.relu_fc(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
# Pick the compute device, then build the network, the loss function and the
# optimizer that updates the network's parameters.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

model = CNN2D(num_classes=2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Using device: cuda


In [7]:
def calculate_accuracy(model, dataloader):
    """Return the fraction of samples in ``dataloader`` that ``model`` gets right.

    The model is switched to evaluation mode and left in that mode. The
    predicted class of a sample is the index of its largest output.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when the iterable yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()

    # Run on whatever device the model lives on rather than relying on a
    # notebook-level ``device`` global; parameter-free models fall back to CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return correct / total

# Training loop
def train_model(model, dataloader, valid_dataloader, criterion, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs and print one summary line per epoch.

    Each epoch runs one optimisation pass over ``dataloader``, then measures
    accuracy on both ``dataloader`` and ``valid_dataloader`` with
    ``calculate_accuracy``.

    Args:
        model: Network to optimise, updated in place.
        dataloader: Training batches of ``(inputs, labels)``; must support ``len()``.
        valid_dataloader: Validation batches of ``(inputs, labels)``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over the training data.
    """
    # Follow the model's own device instead of a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    total_batches = len(dataloader)  # Total number of batches in the training set

    epoch_bar = tqdm(range(num_epochs))
    for epoch in epoch_bar:
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for batch_idx, (inputs, labels) in enumerate(dataloader):
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_size = inputs.size(0)
            running_loss += batch_loss * batch_size
            samples_seen += batch_size

            # Report batch progress on the progress bar rather than printing a
            # line per batch, which floods the cell output on large datasets.
            epoch_bar.set_postfix(
                batch=f"{batch_idx + 1}/{total_batches}",
                loss=f"{batch_loss:.4f}",
            )

        # Average over the samples actually processed: this stays correct when
        # the loader drops the last batch and cannot divide by zero.
        epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
        train_accuracy = calculate_accuracy(model, dataloader)
        valid_accuracy = calculate_accuracy(model, valid_dataloader)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Valid Accuracy: {valid_accuracy:.4f}")


In [None]:
# Time how long the training loader takes to hand over each batch.
# No model work happens here; the loop only pulls batches and reports the gap
# between consecutive ones.
start_time = time.time()
for batch_idx, (inputs, labels) in enumerate(dataloader):
    elapsed = time.time() - start_time
    print(f"Batch {batch_idx+1}, Time taken: {elapsed:.2f} seconds")
    start_time = time.time()

In [19]:
# Run the training loop for 10 epochs with the model, loaders, loss and
# optimizer created earlier in the notebook.
train_model(
    model,
    dataloader,
    valid_dataloader,
    criterion,
    optimizer,
    num_epochs=10,
)

  0%|          | 0/10 [00:00<?, ?it/s]

: 

: 

In [None]:
# SAVE MODEL
# Build the path with os.path.join instead of a hard-coded backslash so it is
# valid on every platform, and create the folder first: torch.save does not
# create missing directories.
model_save_path = os.path.join("trained_models", "first_attempt_cnn2d.pth")
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(model.state_dict(), model_save_path)

In [None]:
# LOAD MODEL

# Reinitialize the model
loaded_model = CNN2D(num_classes=2).to(device)

# Load the saved model weights.
# The file must be the one written by the save cell
# (trained_models/first_attempt_cnn2d.pth); the previous "trained_model.pth"
# did not match any path this notebook writes.
loaded_model.load_state_dict(
    torch.load(
        os.path.join("trained_models", "first_attempt_cnn2d.pth"),
        map_location=device,
    )
)

# Set the model to evaluation mode
loaded_model.eval()

print("Model successfully loaded and ready for inference.")


In [None]:
# SHAP explainability
model.eval()

# Select a batch of images to explain
images, _ = next(iter(dataloader))
images = images.to(device)

# Explain a handful of images and use the remainder of the batch as the
# background (reference) set, so the explained samples are not also their own
# baseline. If the batch is too small to split, the whole batch is the
# background, as before.
n_explain = min(4, images.shape[0])
images_to_explain = images[:n_explain]
background = images[n_explain:] if images.shape[0] > n_explain else images

# Probability-returning wrapper around the model. DeepExplainer below works on
# the model directly and does not call this; it is kept for explainers that
# take a plain prediction function.
def predict(images):
    """Return softmax class probabilities for ``images`` as a NumPy array."""
    with torch.no_grad():
        logits = model(images)
        probs = nn.Softmax(dim=1)(logits)
    return probs.cpu().numpy()

explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(images_to_explain)

# Depending on the SHAP version, shap_values is either a list with one
# (N, C, H, W) array per class or a single array with the class axis last.
# Normalise to the per-class list form.
if isinstance(shap_values, list):
    shap_per_class = shap_values
else:
    shap_values_array = np.asarray(shap_values)
    shap_per_class = [
        shap_values_array[..., class_idx]
        for class_idx in range(shap_values_array.shape[-1])
    ]

# Visualize SHAP explanations
# image_plot expects channels-last (N, H, W, C) arrays, while PyTorch tensors
# are channels-first (N, C, H, W), so move the channel axis to the end.
shap_channels_last = [np.transpose(s, (0, 2, 3, 1)) for s in shap_per_class]
images_channels_last = np.transpose(images_to_explain.cpu().numpy(), (0, 2, 3, 1))
shap.image_plot(shap_channels_last, images_channels_last)