Juelich 2024 #32

Closed · wants to merge 194 commits into from

Commits (194)
cbd3b40
add and modify colab section
donatella-cea Oct 11, 2023
c0d6a81
add first version LIME tutorial
donatella-cea Oct 11, 2023
ec6d9a9
attention map on images
sab148 Oct 12, 2023
b5d0fdf
attnetion tutorial
sab148 Oct 27, 2023
b747933
fixing figure
sab148 Oct 27, 2023
ebe8ffb
add the image
sab148 Oct 27, 2023
9019e79
add model description, address warning on model weights
donatella-cea Oct 30, 2023
e3cde65
more explanations
sab148 Oct 30, 2023
e5a6ed7
Merge branch 'Juelich-2023' of https://github.com/HelmholtzAI-Consult…
sab148 Oct 30, 2023
b7cd55a
move batch_predict to utils
donatella-cea Oct 30, 2023
40bb320
Merge branch 'Juelich-2023' of https://github.com/HelmholtzAI-Consult…
donatella-cea Oct 30, 2023
b9f829c
arrange image
sab148 Oct 30, 2023
c59d4af
Merge branch 'Juelich-2023' of https://github.com/HelmholtzAI-Consult…
sab148 Oct 30, 2023
7a7681d
add image
sab148 Oct 30, 2023
d75778d
add image
sab148 Oct 30, 2023
0628cb7
add pred probability top 5 classes
donatella-cea Oct 30, 2023
8498e1f
cam for signals
sab148 Nov 8, 2023
0caaf95
Merge branch 'Juelich-2023' of https://github.com/HelmholtzAI-Consult…
sab148 Nov 8, 2023
6f57626
first draft cam tutorial
sab148 Nov 8, 2023
9b7c070
render changes
sab148 Nov 8, 2023
546e683
description changes
sab148 Nov 8, 2023
f92417c
ahh image and fix the math markdown
sab148 Nov 8, 2023
a83bf4c
fix math markdown
sab148 Nov 8, 2023
aba13d9
arrange the structure
sab148 Nov 8, 2023
f417d6e
add colab section
donatella-cea Nov 10, 2023
033219e
add colab section
donatella-cea Nov 10, 2023
c4f6675
tutorial transformers
sab148 Nov 12, 2023
535ff31
Merge branch 'Juelich-2023' of https://github.com/HelmholtzAI-Consult…
sab148 Nov 12, 2023
578469a
tutorial transformers
sab148 Nov 13, 2023
3ecd8a4
improved version
sab148 Nov 13, 2023
83bab05
added one more sub section
sab148 Nov 13, 2023
1d3b8b5
tutorial vit
sab148 Nov 13, 2023
f3c60d1
add vit image
sab148 Nov 13, 2023
0a1f83a
change encoder part
sab148 Nov 13, 2023
f099529
add theoretical details
donatella-cea Nov 13, 2023
c29ebe6
finish lime for images, add lime to comparison notebook
donatella-cea Nov 14, 2023
4bd0882
add the cell to download the dataset
sab148 Nov 16, 2023
931cc84
remove warnings
sab148 Nov 16, 2023
9f24720
cam for each class
sab148 Nov 16, 2023
3914db6
remove warnings
sab148 Nov 16, 2023
9721bc4
change weight name, tutorial attention text
sab148 Nov 17, 2023
ab78d99
finished attention tutorial
sab148 Nov 20, 2023
8b1401e
Francesco suggestions replied
sab148 Nov 21, 2023
9442510
typos corrections
sab148 Nov 21, 2023
1b73185
more typos corrections
sab148 Nov 21, 2023
3828ed7
grad-cam vs cam
sab148 Nov 21, 2023
43b55f2
fransesco suggestions on cam
sab148 Nov 22, 2023
c5c5df7
ecg training and correct typos
sab148 Nov 22, 2023
ee14046
restructure repo, add colab sections
donatella-cea Nov 23, 2023
c24a69c
update colab link
donatella-cea Nov 23, 2023
6847f1a
dona suggestions
sab148 Nov 23, 2023
44d2352
update requirements file
sab148 Nov 23, 2023
d82a9d4
upload weights
sab148 Nov 24, 2023
3aa8f10
fixed image
sab148 Nov 24, 2023
d5fdb91
fixe more images
sab148 Nov 24, 2023
4da9797
fiz image sizes for collab
sab148 Nov 24, 2023
59a2eb5
fix embedding image size
sab148 Nov 24, 2023
e3f753d
colab link
donatella-cea Nov 24, 2023
61177dd
typos
donatella-cea Nov 24, 2023
ca1ad5c
fixed postion of code
sab148 Nov 24, 2023
aaa0ce4
add extra session, add seeds
donatella-cea Nov 24, 2023
c957cd6
add exercise SHAP
donatella-cea Nov 28, 2023
173f6b9
Update README.md
donatella-cea Nov 28, 2023
4f5c80f
changed code to add residuals
sab148 Nov 29, 2023
bf2e37d
Merge branch 'Juelich-2023' of https://github.com/HelmholtzAI-Consult…
sab148 Nov 29, 2023
8953cc2
add reference
sab148 Nov 29, 2023
d902e66
more changes to materials
sab148 Nov 29, 2023
2dd4b79
correct typo, add explanations
donatella-cea Nov 29, 2023
b9d897b
change example
sab148 Nov 30, 2023
21ed206
Update Tutorial_LIME_Images.ipynb
neuronflow Nov 30, 2023
cf1adc3
correct typo
donatella-cea Dec 1, 2023
8daee04
Merge pull request #8 from HelmholtzAI-Consultants-Munich/neuronflow-…
donatella-cea Dec 1, 2023
66050fd
remove solutions
donatella-cea Dec 1, 2023
0448130
remove answers
donatella-cea Dec 1, 2023
4476b5d
remove answers
donatella-cea Dec 1, 2023
3a0f972
remove answers
donatella-cea Dec 1, 2023
11c6241
remove answers
donatella-cea Dec 1, 2023
80ce3b6
remove answers
donatella-cea Dec 1, 2023
b67541e
remove answers
donatella-cea Dec 1, 2023
20c8e52
remove answers
donatella-cea Dec 1, 2023
2ccf4c4
remove answers, clear output
donatella-cea Dec 1, 2023
41fa3a2
Update README.md requirement section
donatella-cea Dec 1, 2023
6736d9e
cam explanations
sab148 Dec 3, 2023
c290d36
to markdown
sab148 Dec 3, 2023
294ba98
fix math in the text
sab148 Dec 3, 2023
28c9868
fix math formulas
sab148 Dec 3, 2023
74361f2
fix spacing
sab148 Dec 3, 2023
ef80a86
add modifications for attetion maps
sab148 Dec 3, 2023
349a78b
ad new image
sab148 Dec 3, 2023
8f662d2
Update Tutorial_SHAP.ipynb
neuronflow Dec 4, 2023
6ecd225
Merge pull request #9 from HelmholtzAI-Consultants-Munich/neuronflow-…
neuronflow Dec 4, 2023
d5db9fc
Update Tutorial_Grad-CAM.ipynb
neuronflow Dec 4, 2023
6d4343f
Merge pull request #10 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
968b9f7
Update Tutorial_SHAP.ipynb
neuronflow Dec 4, 2023
7633560
Merge pull request #14 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
0a3a9fc
Update Tutorial_PermutationFeatureImportance.ipynb
neuronflow Dec 4, 2023
20ba57d
Add sklearn to requirements
IsraMekki0 Dec 4, 2023
269ae65
Update Tutorial_PermutationFeatureImportance.ipynb
neuronflow Dec 4, 2023
41da5d1
Merge pull request #15 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
c0ae17b
Update Tutorial_SHAP.ipynb
neuronflow Dec 4, 2023
ebdabc8
Merge pull request #16 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
9e5de20
Update requirements.txt
IsraMekki0 Dec 4, 2023
0aa9cd2
Update Tutorial_LIME.ipynb
neuronflow Dec 4, 2023
b4f0a78
Merge pull request #17 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
d35295b
Update Tutorial_LIME.ipynb
neuronflow Dec 4, 2023
38d1484
Merge pull request #18 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
f53da38
Update Tutorial_FGC.ipynb
neuronflow Dec 4, 2023
bdc21f7
Merge pull request #19 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
153e15f
lang
neuronflow Dec 4, 2023
389a527
Update Tutorial_Grad-CAM.ipynb
neuronflow Dec 4, 2023
ee4d1e7
Update Tutorial_Grad-CAM.ipynb
neuronflow Dec 4, 2023
aa1b08e
Update Tutorial_SHAP_Images.ipynb
neuronflow Dec 4, 2023
9a266d7
Merge pull request #21 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
c905430
Update Tutorial_LIME_Images.ipynb
neuronflow Dec 4, 2023
96e9238
Merge pull request #22 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
36e2009
Update Tutorial_XAI_for_ImageAnalysis.ipynb
neuronflow Dec 4, 2023
2c6e1fd
Merge pull request #23 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
3017b1b
Update README.md
neuronflow Dec 4, 2023
b441050
Merge pull request #24 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
3fd019d
Update Model-Transformers.ipynb
neuronflow Dec 4, 2023
195f9ef
Merge pull request #25 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
6e0455a
Update Tutorial_attention_map_for_text.ipynb
neuronflow Dec 4, 2023
1a07dfa
Merge pull request #26 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
c82db37
Update Model-VIT.ipynb
neuronflow Dec 4, 2023
8572fc0
Merge pull request #27 from HelmholtzAI-Consultants-Munich/neuronflow…
neuronflow Dec 4, 2023
b4d87df
language fixes
neuronflow Dec 4, 2023
a9485d7
Merge pull request #28 from HelmholtzAI-Consultants-Munich/language_a…
neuronflow Dec 4, 2023
fc3bdd9
lang
neuronflow Dec 4, 2023
5a1e278
Merge pull request #30 from HelmholtzAI-Consultants-Munich/lang_cam_s…
neuronflow Dec 4, 2023
ce1566e
Merge pull request #20 from HelmholtzAI-Consultants-Munich/feature_vi…
sab148 Dec 4, 2023
4414da6
resolve merge conflict
lisa-sousa Dec 4, 2023
05601be
Merge pull request #13 from HelmholtzAI-Consultants-Munich/neuronflow…
lisa-sousa Dec 4, 2023
79f4fef
add intro slides
donatella-cea Dec 4, 2023
871eb5d
add lime installation
donatella-cea Dec 5, 2023
d7aed86
Changed normal examples
IsraMekki0 Dec 5, 2023
a7112e8
Corrected previous commit
IsraMekki0 Dec 5, 2023
bb23814
add weights
donatella-cea Dec 6, 2023
de2a112
rename weights file
donatella-cea Dec 6, 2023
82b7938
fix weights path in attention maps for text
sab148 Dec 6, 2023
ce057b8
modify link to last image
donatella-cea Dec 6, 2023
cc5bd53
modify link to last image
donatella-cea Dec 6, 2023
391b2d4
add answers
donatella-cea Dec 8, 2023
d5bd483
add answers
donatella-cea Dec 8, 2023
df93cde
add answers
donatella-cea Dec 8, 2023
0eb9e50
add answers
donatella-cea Dec 8, 2023
17a68c1
add answers
donatella-cea Dec 8, 2023
92e854f
add answers
donatella-cea Dec 8, 2023
c9c3013
add answers
donatella-cea Dec 8, 2023
015dde5
add answers
donatella-cea Dec 8, 2023
38869d8
add answers
donatella-cea Dec 8, 2023
0cc3614
restructure repo
donatella-cea Apr 12, 2024
67542c1
Delete Introduction_to_XAI_Juelich_2023.pdf
donatella-cea Apr 12, 2024
6fbfedc
Update README.md
donatella-cea Apr 12, 2024
d5ec70d
Update README.md
donatella-cea Apr 26, 2024
4139f52
grad cam for signals
sab148 Apr 29, 2024
8d08b1f
remove cam tutorial
sab148 Apr 29, 2024
f86bd66
change name of tutorial
sab148 Apr 29, 2024
3d93bfb
update colab link
donatella-cea Apr 30, 2024
814d873
update colab link
donatella-cea Apr 30, 2024
90aa883
update colab link
donatella-cea Apr 30, 2024
88595ad
change colab link
donatella-cea May 2, 2024
c9a9d9f
change colab link
donatella-cea May 2, 2024
243f2ab
change colab link
donatella-cea May 2, 2024
92d8d9c
change colab link
donatella-cea May 2, 2024
8bb83bc
change colab link
donatella-cea May 2, 2024
81331b6
change colab link
donatella-cea May 2, 2024
4433c1b
change colab link
donatella-cea May 2, 2024
0a60237
change colab link
donatella-cea May 2, 2024
77bdc3b
change colab link
donatella-cea May 2, 2024
c3f023c
change colab link
donatella-cea May 2, 2024
b70991d
change colab link
donatella-cea May 2, 2024
81b6397
change colab link
donatella-cea May 2, 2024
70b828a
change colab link
donatella-cea May 2, 2024
89fa53a
change colab link
donatella-cea May 2, 2024
bf47c4e
rename notebook, change colab link
donatella-cea May 2, 2024
fde88f5
Update README.md
donatella-cea May 2, 2024
1d5d3c3
add gradient explainer explanation
donatella-cea May 3, 2024
61170ba
moved transformer theory to rtd
lisa-sousa May 3, 2024
f101919
add GradientExplainer explanation
donatella-cea May 3, 2024
5753839
t5 weights for German
sab148 May 3, 2024
3751905
ECG test data
sab148 May 3, 2024
784391d
Corrected paths
IsraMekki0 May 6, 2024
7987f27
fig bug in get_prediction() and remove folder
sab148 May 7, 2024
95dc2df
Merge branch 'Juelich-2024' of https://github.com/HelmholtzAI-Consult…
sab148 May 7, 2024
1fff27f
add test data for signals
sab148 May 7, 2024
5c6253d
change location of test data
sab148 May 7, 2024
d2d7b35
german english weights
sab148 May 7, 2024
3498b86
remove weights
sab148 May 7, 2024
225b549
training transformer for english german translation
sab148 May 7, 2024
5ccbe24
change weights name
sab148 May 7, 2024
9803e3f
zip location changed
sab148 May 7, 2024
dd888bc
Merge branch 'Juelich-2024' of github.com:HelmholtzAI-Consultants-Mun…
IsraMekki0 May 8, 2024
d20b917
Download model weights from google drive
IsraMekki0 May 8, 2024
c6fc23d
Merge pull request #31 from HelmholtzAI-Consultants-Munich/Juelich-20…
sab148 May 8, 2024
80 changes: 69 additions & 11 deletions README.md
@@ -7,8 +7,8 @@
# Tutorials for eXplainable Artificial Intelligence (XAI) methods

This repository contains a collection of self-explanatory tutorials for different model-agnostic and model-specific XAI methods.
Each tutorial comes in a Jupyter Notebook containing a short video lecture and practical exercises.
The material has already been used in the context of two courses: the Zero to Hero Summer Academy (fully online) and ml4hearth (hybrid setting).
Each tutorial comes in a Jupyter Notebook, containing a short video lecture and practical exercises.
The material has already been used in the context of several courses: the Helmholtz Summer Academy 2022 and 2023 (fully online), ml4hearth (hybrid setting), and the MALTAomics Summer School (hybrid setting).
The course material can be adjusted according to the available time frame and the schedule.
The material is self-explanatory and can also be consumed offline.

@@ -17,23 +17,81 @@ The learning objectives are:
- understand the importance of interpretability
- discover the existing model-agnostic and model-specific XAI methods
- learn how to interpret the outputs and graphs of those methods with hands-on exercises
- learn to choose which method is suitable for a specific task
- learn to chose which method is suitable for a specific task

## List of Tutorials for Model-Agnostic Methods
## Venue
The course will be fully online:
*Add zoom link*

- Permutation Feature Importance
- SHapley Additive exPlanations (SHAP)
- Local Interpretable Model-Agnostic Explanations (LIME)
Link to the share notes: https://notes.desy.de/HAmuRdemQgK8VW9mqjBBPQ?both

## List of Tutorials for Model-Specific Methods
## Schedule at a glance

- Forest-Guided Clustering
- Grad-CAM
#### Day 1 - XAI for Random Forest
| Time | Session | Duration |
|---|---|---|
| 9:00 - 9:30 | Introduction | 30 min |
| 9:30 - 10:15 | Permutation Feature Importance | 45 min |
| 10:15 - 10:30 | Break | 15 min |
| 10:30 - 11:30 | SHAP | 1 h |
| 11:30 - 11:45 | Break | 15 min |
| 11:45 - 12:15 | LIME | 30 min |
| 12:15 - 12:55 | FGC | 40 min |
| 12:55 - 13:00 | Conclusions | 5 min |

## Requirements and Setup
Homework 1: Comparison notebook - [Tutorial_XAI_for_RandomForest](https://github.com/HelmholtzAI-Consultants-Munich/XAI-Tutorials/blob/Juelich-2024/xai-for-tabular-data/Tutorial_XAI_for_RandomForests.ipynb)

Homework 2: SHAP exercise - [Compute Shapley values by hand](https://github.com/HelmholtzAI-Consultants-Munich/XAI-Tutorials/blob/Juelich-2024/SHAP_exercise.pdf)
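For orientation, the quantity computed by hand in Homework 2 is the Shapley value of a feature $i$ over the feature set $F$ with value function $v$; this is the standard definition, stated here for reference rather than taken from the exercise sheet:

$$\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F| - |S| - 1)!}{|F|!}\,\bigl(v(S \cup \{i\}) - v(S)\bigr)$$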


#### Day 2 - XAI for CNNs
| Time | Session | Duration |
|---|---|---|
| 9:00 - 9:15 | Welcome | 15 min |
| 9:15 - 9:30 | Homework Discussion | 15 min |
| 9:30 - 10:00 | Intro CNNs | 30 min |
| 10:00 - 10:15 | Break | 15 min |
| 10:15 - 11:05 | Grad-CAM for Images | 50 min |
| 11:05 - 11:45 | Grad-CAM for Signals | 40 min |
| 11:45 - 12:00 | Break | 15 min |
| 12:00 - 12:30 | LIME for Images | 30 min |
| 12:30 - 12:55 | SHAP for Images | 25 min |
| 12:55 - 13:00 | Conclusions | 5 min |

Homework 1: Comparison notebook - [Tutorial_XAI_for_ImageAnalysis](https://github.com/HelmholtzAI-Consultants-Munich/XAI-Tutorials/blob/Juelich-2024/xai-for-image-data/Tutorial_XAI_for_ImageAnalysis.ipynb)


#### Day 3 - XAI for Transformers
| Time | Session | Duration |
|---|---|---|
| 9:00 - 9:15 | Welcome | 15 min |
| 9:15 - 9:30 | Homework Discussion | 15 min |
| 9:30 - 10:15 | Intro to Transformers | 45 min |
| 10:15 - 10:30 | Break | 15 min |
| 10:30 - 11:00 | Attention for text | 30 min |
| 11:00 - 11:30 | Intro to Vision Transformers | 30 min |
| 11:30 - 11:45 | Break | 15 min |
| 11:45 - 12:45 | Attention map for image transformers | 60 min |
| 12:45 - 13:00 | Conclusions & Survey | 15 min |


## Requirements and Setup - *Check this section*

You can either create a local environment and install all the necessary packages (using the requirements.txt file) or run the notebooks in the browser by clicking the 'Open in Colab' button. The second option requires no further installation, but you need access to a Google account.

If you prefer to run the notebooks on your device, create a virtual environment using the requirements.txt file:
```
conda create -n XAI-Course-2024 python=3.9
conda activate XAI-Course-2024
pip install -r requirements.txt
```

Once your environment is created, clone the `Juelich-2024` branch of the repo using the following command:

```
git clone --branch Juelich-2024 https://github.com/HelmholtzAI-Consultants-Munich/XAI-Tutorials.git
```

## Contributions

Comments and input are very welcome! If you have a suggestion or think something should be changed, please open an issue or submit a pull request.
Binary file added SHAP_exercise.pdf
Binary file not shown.
23 changes: 23 additions & 0 deletions data_and_models/ECG/ECG.py
@@ -0,0 +1,23 @@
#%%
import torch
import torch.utils.data as data
import pandas as pd
import numpy as np

class ECG(data.Dataset):
    """ECG beat dataset read from an MIT-BIH style CSV file.

    Each row holds 187 signal samples followed by the integer class label
    in column 187.
    """

    def __init__(self, data):
        self.data = pd.read_csv(data, sep=',', header=None)
        self.samples = self.data.iloc[:, :187]        # signal values
        self.targets = self.data[187].to_numpy()      # class labels

    def __getitem__(self, index):
        x = self.samples.iloc[index, :]
        x = torch.from_numpy(x.values).float()
        x = torch.unsqueeze(x, 0)                     # add channel dimension -> (1, 187)
        y = self.targets[index].astype(np.int64)
        return x, y

    def __len__(self):
        return len(self.data)
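
A minimal usage sketch for this Dataset class follows; the CSV path is an assumption taken from `main.py` further down in this diff, and the printed shapes follow from the `__getitem__` definition above.

```
# Hedged usage sketch for the ECG dataset class above (path is an assumption).
from torch.utils.data import DataLoader

from ECG import ECG

dataset = ECG("data/Dataset_ECG/mitbih_test.csv")   # 187 samples per row + label in column 187
loader = DataLoader(dataset, batch_size=64, shuffle=False)

x, y = dataset[0]
print(x.shape)  # torch.Size([1, 187]) -- one channel, 187 time steps
print(y)        # integer class label (0-4 for the five classes used in main.py)
```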

Binary file added data_and_models/ECG/ECG_test_data.zip
Binary file not shown.
93 changes: 93 additions & 0 deletions data_and_models/ECG/ResNet1D.py
@@ -0,0 +1,93 @@
#%%
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResNetBlock, self).__init__()
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm1d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm1d(out_channels)
self.stride = stride

self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm1d(out_channels)
)

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

out += self.shortcut(residual)
out = self.relu(out)

return out

class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 64

self.conv1 = nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm1d(64)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self.make_layer(block, 64, layers[0])
self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
self.layer4 = self.make_layer(block, 512, layers[3], stride=2)
self.avg_pool = nn.AdaptiveAvgPool1d((1,))
self.fc = nn.Linear(512, num_classes)
self.gradient = None

def make_layer(self, block, out_channels, blocks, stride=1):
layers = []
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels
for i in range(1, blocks):
layers.append(block(out_channels, out_channels))
return nn.Sequential(*layers)

# hook for the gradients
def activations_hook(self, grad):
self.gradient = grad

def get_gradient(self):
return self.gradient

    def get_activations(self, x):
        # Feature maps of the last convolutional block, used by Grad-CAM.
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return self.layer4(x)

def forward(self, x, label=None, return_cam=False):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
        if x.requires_grad:  # hooks cannot be registered under torch.no_grad()
            x.register_hook(self.activations_hook)

pre_logits = self.avg_pool(x)
pre_logits = torch.flatten(pre_logits, 1)
logits = self.fc(pre_logits)

if return_cam:
feature_map = x.detach().clone()
cam_weights = self.fc.weight[label]
cams = (cam_weights.view(*feature_map.shape[:2], 1) *
feature_map).mean(1, keepdim=False)
return logits, cams

return logits, x
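
To show how the hook machinery above is meant to be used, here is a hedged Grad-CAM sketch for a single ECG beat. The commented weights path is an assumption taken from `main.py`, the input length comes from `ECG.py`, and in practice trained weights would be loaded before interpreting the map.

```
# Hedged Grad-CAM sketch built on the hooks defined above; not the tutorial's
# exact code. The dummy input mimics one ECG beat of 187 samples.
import torch

from ResNet1D import ResNet, ResNetBlock

model = ResNet(ResNetBlock, [2, 2, 2, 2], num_classes=5)
# model.load_state_dict(torch.load("weights/model_final_weights_ecg.pth", map_location="cpu"))
model.eval()

signal = torch.randn(1, 1, 187)              # (batch, channel, length) dummy ECG beat
logits, activations = model(signal)          # forward() registers the gradient hook on the last block
pred_class = logits.argmax(dim=-1).item()

logits[0, pred_class].backward()             # fills model.gradient via activations_hook
grads = model.get_gradient()                 # (1, 512, L) gradients w.r.t. the last conv block
weights = grads.mean(dim=2, keepdim=True)    # global-average-pool the gradients per channel
cam = torch.relu((weights * activations.detach()).sum(dim=1))  # (1, L) class activation map
cam = cam / (cam.max() + 1e-8)               # normalise to [0, 1] for plotting
```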
110 changes: 110 additions & 0 deletions data_and_models/ECG/main.py
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

from ResNet1D import ResNet, ResNetBlock
from ECG import ECG



def training_loop(model, train_loader, criterion, optimizer, num_epochs, writer, device):

for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs.to(device))
loss = criterion(outputs[0], labels.to(device))
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))

torch.save(model.state_dict(), 'weights/model_weights'+str(epoch)+'.pth')

writer.add_scalar('training loss', running_loss/len(train_loader), epoch)

torch.save(model.state_dict(), 'weights/model_final_weights_ecg.pth')

def inference_loop(model, test_loader, criterion):


model.eval()

device = next(model.parameters()).device
true_positives = [0] * 5
false_positives = [0] * 5
false_negatives = [0] * 5
true_negatives = [0] * 5
total_loss = 0
total_samples = 0

with torch.no_grad():
for inputs, targets in test_loader:
targets = targets.to(device)
inputs = inputs.to(device)
outputs = model(inputs)
outputs = outputs[0]

loss = criterion(outputs, targets)

total_loss += loss.item()
total_samples += inputs.size(0)
_, predictions = torch.max(outputs, -1)

for i in range(5):

true_positives[i] += ((predictions == i) & (targets == i)).sum().item()
false_positives[i] += ((predictions == i) & (targets != i)).sum().item()
false_negatives[i] += ((predictions != i) & (targets == i)).sum().item()
true_negatives[i] += ((predictions != i) & (targets != i)).sum().item()

accuracy = (sum(true_positives) + sum(true_negatives)) / (sum(true_positives) + sum(false_positives) + sum(false_negatives) + sum(true_negatives))
recall = [true_positives[i] / (true_positives[i] + false_negatives[i]) if true_positives[i] + false_negatives[i] > 0 else 0 for i in range(5)]
precision = [true_positives[i] / (true_positives[i] + false_positives[i]) if true_positives[i] + false_positives[i] > 0 else 0 for i in range(5)]
f1 = [2 * precision[i] * recall[i] / (precision[i] + recall[i]) if precision[i] + recall[i] > 0 else 0 for i in range(5)]
average_loss = total_loss / total_samples

return accuracy, recall, precision, f1, average_loss


if __name__ == "__main__":

model_weights_path = 'weights/model_final_weights_ecg.pth'
train_data_path = 'data/Dataset_ECG/mitbih_train.csv'
test_data_path = 'data/Dataset_ECG/mitbih_test.csv'
num_epochs = 10

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = ResNet(ResNetBlock, [2, 2, 2, 2], num_classes=5)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

ecg = ECG(train_data_path)
train_loader = DataLoader(ecg, batch_size=64, shuffle=True, num_workers=4)

ecg = ECG(test_data_path)
test_loader = DataLoader(ecg, batch_size=64, shuffle=True, num_workers=4)

writer = SummaryWriter()

# training_loop(model, train_loader, criterion, optimizer, num_epochs, writer, device)

    model.load_state_dict(torch.load(model_weights_path, map_location=device))  # map_location keeps CPU-only machines working

accuracy, recall, precision, f1, average_loss = inference_loop(model, test_loader, criterion)

print("Accuracy: ", accuracy)
print("Recall: ", recall)
print("Precision: ", precision)
print("F1: ", f1)
print("Average loss: ", average_loss)
