In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchviz import make_dot
from torchinfo import summary
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
import seaborn as sns

In [2]:
from torchvision import models

In [3]:
class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        self.model = models.efficientnet_b0(pretrained=True)  # Load pretrained EfficientNet-B0
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, 128)  # Adjust output layer
    
    def forward(self, x):
        return self.model(x)

In [4]:
class AttentionLayer(nn.Module):
    def __init__(self, input_dim):
        super(AttentionLayer, self).__init__()
        self.W = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        attn_weights = F.softmax(self.W(x), dim=1)
        return x * attn_weights

In [5]:
class MetadataModel(nn.Module):
    def __init__(self, input_size):
        super(MetadataModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            AttentionLayer(128),  # Apply attention to highlight key features
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.LayerNorm(64),
            nn.Linear(64, 32)
        )
    
    def forward(self, x):
        return self.fc(x)

In [6]:
# Define Multimodal Model
class MultimodalModel(nn.Module):
    def __init__(self, image_model, metadata_model):
        super(MultimodalModel, self).__init__()
        self.image_model = image_model
        self.metadata_model = metadata_model
        self.classifier = nn.Linear(128 + 32, 2)  # Combining both feature sets
    
    def forward(self, image, metadata):
        img_features = self.image_model(image)
        meta_features = self.metadata_model(metadata)
        combined = torch.cat((img_features, meta_features), dim=1)
        return self.classifier(combined)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_model = ImageModel().to(device)
metadata_model = MetadataModel(input_size=7).to(device)
model = MultimodalModel(image_model, metadata_model).to(device)



In [8]:
# Load saved checkpoint
checkpoint_path = "C:/Data/DJ/SkinCancer/code/pytorch_models/mra_midas_efficientnetB0.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # Set to evaluation mode

print("Model loaded successfully!")

Model loaded successfully!


  checkpoint = torch.load(checkpoint_path, map_location=device)


In [16]:
from torchviz import make_dot

# Get a sample input
sample_image = torch.randn(1, 3, 224, 224).to(device)
sample_metadata = torch.randn(1, 7).to(device)
sample_output = model(sample_image, sample_metadata)

# Generate visualization
dot = make_dot(sample_output, params=dict(model.named_parameters()))
dot.render("model_visualization", format="png")  # Save as PNG

'model_visualization.png'

In [23]:
from torchviz import make_dot

# Get a sample input
sample_image = torch.randn(1, 3, 224, 224).to(device)
sample_metadata = torch.randn(1, 7).to(device)

# Extract the final layer outputs instead of the entire model
img_features = model.image_model(sample_image)
meta_features = model.metadata_model(sample_metadata)
combined_features = torch.cat((img_features, meta_features), dim=1)
final_output = model.classifier(combined_features)

# Generate visualization (only for the classifier)
dot = make_dot(final_output, params=dict(model.classifier.named_parameters()))
dot.render("model_visualization_simplified-1", format="png")


'model_visualization_simplified-1.png'

In [None]:
import torch
from torchviz import make_dot

# Get a sample input
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample_image = torch.randn(1, 3, 224, 224).to(device)
sample_metadata = torch.randn(1, 7).to(device)

# Extract features separately
img_features = model.image_model(sample_image)
meta_features = model.metadata_model(sample_metadata)

# Generate individual graphs
dot_img = make_dot(img_features, params=dict(model.image_model.named_parameters()))
dot_meta = make_dot(meta_features, params=dict(model.metadata_model.named_parameters()))

# Assign colors
for node in dot_img.body:
    node = node.replace('fillcolor=black', 'fillcolor=lightblue')

for node in dot_meta.body:
    node = node.replace('fillcolor=black', 'fillcolor=lightgreen')

# Combine features
combined_features = torch.cat((img_features, meta_features), dim=1)
final_output = model.classifier(combined_features)

# Generate final visualization
dot_final = make_dot(final_output, params=dict(model.classifier.named_parameters()))

# Merge color changes
dot_final.body.extend(dot_img.body)
dot_final.body.extend(dot_meta.body)

# Save and render
dot_final.render("model_visualization_simplified_colored", format="png")


'model_visualization_simplified_colored.png'

In [21]:
for param in model.image_model.model.features.parameters():
    param.requires_grad = False  # Freeze feature extractor

dot = make_dot(sample_output, params=dict(model.classifier.named_parameters()))
dot.render("model_visualization_small-2", format="png")


'model_visualization_small-2.png'

In [22]:
dot = make_dot(sample_output, params={name: p for name, p in model.named_parameters() if 'classifier' in name or 'fc' in name})
dot.render("model_visualization_filtered-3", format="png")

'model_visualization_filtered-3.png'

In [4]:
!pip install graphviz





[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [17]:
torch.save(model, "model.pth")

In [18]:
!pip install netron

Collecting netron
  Downloading netron-8.1.9-py3-none-any.whl.metadata (1.5 kB)
Downloading netron-8.1.9-py3-none-any.whl (1.9 MB)
   ---------------------------------------- 0.0/1.9 MB ? eta -:--:--
   ----- ---------------------------------- 0.3/1.9 MB ? eta -:--:--
   ----- ---------------------------------- 0.3/1.9 MB ? eta -:--:--
   ---------------- ----------------------- 0.8/1.9 MB 1.1 MB/s eta 0:00:02
   --------------------- ------------------ 1.0/1.9 MB 1.3 MB/s eta 0:00:01
   --------------------------- ------------ 1.3/1.9 MB 1.4 MB/s eta 0:00:01
   -------------------------------- ------- 1.6/1.9 MB 1.3 MB/s eta 0:00:01
   ---------------------------------------- 1.9/1.9 MB 1.4 MB/s eta 0:00:00
Installing collected packages: netron
Successfully installed netron-8.1.9



[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [24]:
!netron model.pth

^C
