In [1]:
"""
This tutorial walks through the steps of training a ResNet18 model on the Fashion MNIST dataset and analyzing the results with GTDA.
Please first set up your environment, either with the GTDA.yml file included in this repo or manually.
"""

import torch
from torchmetrics import Accuracy
import torchvision.models as torch_models
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from pytorch_lightning import Trainer,seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from argparse import Namespace

## Define ResNet18 model

In [2]:
"""
We use the ImageNet-pretrained ResNet18 model from torchvision as the backbone for our own model.
Since the Fashion MNIST dataset has only 10 classes, we need to replace the last fully connected layer.
We also need to define "training_step", "validation_step", "test_step" and "configure_optimizers" in order to use the PyTorch Lightning framework.
"""
class MyResNet(pl.LightningModule):
    """Lightning wrapper around a torchvision ResNet-18 classifier.

    The backbone is created with ``pretrained=True``; its first convolution is
    swapped for one taking ``in_channels`` input channels and its final fully
    connected layer for one producing ``args.num_classes`` outputs.

    Args:
        args: configuration object; this class reads ``args.num_classes`` and
            ``args.learning_rate``.
        in_channels: number of channels of the input images (1 = grayscale).
    """

    def __init__(self, args, in_channels=1):
        super().__init__()
        self.args = args
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy()

        backbone = torch_models.resnet18(pretrained=True)
        # The stock ResNet stem is
        #   nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # i.e. it expects RGB input. It is replaced by a newly constructed layer
        # with the same geometry but `in_channels` inputs, so this layer does
        # not carry over the pretrained stem weights.
        backbone.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        # New classification head sized for this dataset's label count.
        backbone.fc = nn.Linear(backbone.fc.in_features, args.num_classes)
        self.model = backbone

    def forward(self, batch):
        """Score one ``(images, labels)`` batch.

        Returns:
            ``(loss, accuracy)`` where ``loss`` is the cross-entropy of the
            backbone's outputs and ``accuracy`` is the metric value times 100.
        """
        inputs, targets = batch
        logits = self.model(inputs)
        batch_loss = self.criterion(logits, targets)
        batch_acc = self.accuracy(logits, targets)
        return batch_loss, batch_acc * 100

    def _log_epoch(self, metrics):
        """Log every ``name -> value`` pair as an epoch-level (not per-step) metric."""
        for name, value in metrics.items():
            self.log(name, value, on_step=False, on_epoch=True)

    def training_step(self, batch, batch_nb):
        """Log train loss/accuracy and return the loss for optimisation."""
        batch_loss, batch_acc = self.forward(batch)
        self._log_epoch({"loss/train": batch_loss, "acc/train": batch_acc})
        return batch_loss

    def validation_step(self, batch, batch_nb):
        """Log validation loss/accuracy; returns nothing."""
        batch_loss, batch_acc = self.forward(batch)
        self._log_epoch({"loss/val": batch_loss, "acc/val": batch_acc})

    def test_step(self, batch, batch_nb):
        """Log test accuracy only (the loss is computed but not logged)."""
        _, batch_acc = self.forward(batch)
        self._log_epoch({"acc/test": batch_acc})

    def configure_optimizers(self):
        """Return an Adam optimizer over the backbone at ``args.learning_rate``.

        No weight decay and no LR scheduler are configured here. An earlier
        SGD + StepLR variant (driven by ``args.weight_decay``,
        ``args.lr_step_size`` and ``args.lr_gamma``) is no longer used, so
        those config fields have no effect on this optimizer.
        """
        return torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)

## Define Fashion MNIST dataset

In [3]:
class MyFashionMNIST(pl.LightningDataModule):
    """Lightning data module serving the Fashion MNIST train and test splits.

    Both splits are downloaded (if needed) and instantiated in ``__init__`` so
    that ``train_dataset`` / ``test_dataset`` are available immediately after
    construction.

    Note that there is no separate validation split: ``val_dataloader`` and
    ``test_dataloader`` both iterate over the test split.

    Args:
        args: configuration object; reads ``batch_size``, ``num_workers``,
            ``shuffle``, ``drop_last`` and ``pin_memory``. These are looked up
            each time a dataloader is requested, not at construction time.
        train_transform: transform applied to training images.
        test_transform: transform applied to test images.
        root: directory the dataset is stored in / downloaded to. Defaults to
            the location previously hard-coded here.
    """

    def __init__(
        self, args, train_transform=None, test_transform=None, root="../dataset/"):
        super().__init__()
        self.args = args
        self.root = root
        self.train_dataset = FashionMNIST(root=root,train=True,download=True,transform=train_transform)
        self.test_dataset = FashionMNIST(root=root,train=False,download=True,transform=test_transform)

    def train_dataloader(self):
        """Return a loader over the training split, configured from ``self.args``."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            shuffle=self.args.shuffle,
            drop_last=self.args.drop_last,
            pin_memory=self.args.pin_memory,
        )

    def val_dataloader(self):
        """Return an ordered, complete loader over the *test* split.

        Shuffling and ``drop_last`` are always off here, regardless of ``args``.
        """
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            drop_last=False,
            pin_memory=self.args.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return the same loader configuration as ``val_dataloader``."""
        return self.val_dataloader()

In [4]:
# Run configuration shared by the model, the data module and the trainer.
# NOTE(review): weight_decay, gpu_id, lr_warmup, lr_gamma and lr_step_size are
# not read by any code visible in this notebook.
args = Namespace(
    batch_size=256,
    weight_decay=1e-2,
    learning_rate=1e-3,
    max_epochs=100,
    num_workers=8,
    num_classes=10,
    precision=32,
    gpu_id=0,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    lr_warmup=0.2,
    lr_gamma=0.1,
    lr_step_size=20,
)

# Raw training images as a float tensor, used only to derive normalisation
# statistics. `.data` is the current attribute name; `.train_data` is a
# deprecated alias for it.
fashion_mnist = FashionMNIST(root="../dataset/",train=True,download=True).data.float()

# Statistics are divided by 255 because ToTensor rescales pixel values before
# Normalize is applied. Computed once and shared by both pipelines.
pixel_mean = fashion_mnist.mean()/255
pixel_std = fashion_mnist.std()/255

# For the training transform, we use standard normalization and data augmentation.
train_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomAffine(degrees=20, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((pixel_mean,), (pixel_std,)),
])

# The test transform is deterministic: resize + normalisation only.
test_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize((pixel_mean,), (pixel_std,)),
])

# Track the best "acc/val" checkpoint and additionally always keep the last one.
checkpoint = ModelCheckpoint(
    monitor="acc/val", mode="max", save_last=True)

# Seed before building the model so the newly created layers are reproducible.
seed_everything(42, workers=True)
lightning_model = MyResNet(args)
logger = TensorBoardLogger("FashionMNIST", name="resnet50")
trainer = Trainer(
    logger=logger,
    gpus=-1,
    deterministic=True,
    weights_summary=None,
    log_every_n_steps=1,
    max_epochs=args.max_epochs,
    callbacks=[checkpoint],
    precision=args.precision,
)
data = MyFashionMNIST(
    args,train_transform=train_transform,
    test_transform=test_transform)

# Both splits must agree on the class-name -> index mapping.
assert data.train_dataset.class_to_idx == data.test_dataset.class_to_idx

Global seed set to 42
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [11]:
# Reuse the weights of a previous run when its checkpoint is present; otherwise
# train from scratch so the notebook still runs end-to-end on a fresh clone.
CKPT_PATH = "FashionMNIST/resnet50/version_14/checkpoints/last.ckpt"
try:
    ckpt = torch.load(CKPT_PATH)
except FileNotFoundError:
    # No saved run at CKPT_PATH: fit the model with the trainer defined above.
    trainer.fit(lightning_model, data)
else:
    # load_state_dict is strict by default, so any missing or unexpected key
    # raises here instead of being silently ignored.
    lightning_model.load_state_dict(ckpt['state_dict'])

<All keys matched successfully>

In [12]:
# Evaluate on the loader returned by the data module's test_dataloader(); the
# returned list of metric dicts is the cell's displayed output.
test_loader = data.test_dataloader()
trainer.test(model=lightning_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc/test': 93.9800033569336}
--------------------------------------------------------------------------------


[{'acc/test': 93.9800033569336}]

In [13]:
import sys
sys.path.append("../")
import numpy as np
from GTDA.GTDA_utils import SPoC,normalize,knn_cuda_graph
from sklearn.decomposition import PCA
import scipy.sparse as sp
from GTDA.GTDA_utils import compute_reeb, NN_model
from GTDA.GTDA import GTDA

In [14]:
from tqdm import tqdm

# Use the bare torchvision backbone (not the Lightning wrapper) in eval mode.
cnn_model = lightning_model.model
cnn_model.eval()

# Feature extraction needs an ordered, complete pass over the training set.
# Override shuffle/drop_last on a *copy* of the config: the `data` module built
# for training holds a reference to `args` and reads `args.shuffle` every time
# train_dataloader() is called, so mutating `args` in place would silently
# disable shuffling for any later trainer.fit(lightning_model, data).
eval_args = Namespace(**vars(args))
eval_args.shuffle = False
eval_args.drop_last = False

# Both splits use the deterministic test transform (no augmentation).
data_orig = MyFashionMNIST(
    eval_args,train_transform=test_transform,
    test_transform=test_transform)
trainset_orig = data_orig.train_dataset
testset_orig = data_orig.test_dataset
trainloader_orig = data_orig.train_dataloader()
testloader_orig = data_orig.test_dataloader()

# Node ids: training samples first, then test samples; there is no validation set.
# NOTE(review): assumes SPoC concatenates results in the order the loaders are
# passed -- confirm against GTDA_utils.
train_nodes = list(range(len(trainset_orig)))
val_nodes = []
test_nodes = list(range(len(trainset_orig),len(trainset_orig)+len(testset_orig)))

# Two passes over the data: labels and predictions are taken from the
# pooling='avg' call, the feature matrix from the pooling='max' call.
# NOTE(review): the return triple looks like (features, labels, predictions),
# judging by how it is unpacked -- confirm against GTDA_utils.SPoC.
_,y,preds_orig = SPoC(cnn_model,[trainloader_orig,testloader_orig],pooling='avg')
y = np.array(y)
X_orig,_,_ = SPoC(cnn_model,[trainloader_orig,testloader_orig],pooling='max')

# Reduce to 128 PCA components, scale each component by the inverse of its
# singular value, then apply GTDA's normalize().
pca = PCA(n_components=128,random_state=42)
Xr_orig = pca.fit_transform(X_orig)
Dinv = sp.spdiags(1/pca.singular_values_,0,Xr_orig.shape[1],Xr_orig.shape[1])
Xr_orig = Xr_orig@Dinv
Xr_orig = normalize(Xr_orig)

# Build the KNN graph on the GPU and binarise it into a 0/1 float64 matrix.
# NOTE(review): 5 and 256 are presumably the neighbour count and the batch
# size -- confirm against GTDA_utils.knn_cuda_graph.
Xr_orig = torch.tensor(Xr_orig).to('cuda')
A_knn_orig = knn_cuda_graph(Xr_orig,5,256)
G = (A_knn_orig>0).astype(np.float64)

100%|██████████| 235/235 [00:03<00:00, 73.21it/s] 
100%|██████████| 40/40 [00:01<00:00, 29.27it/s]
100%|██████████| 235/235 [00:03<00:00, 72.37it/s] 
100%|██████████| 40/40 [00:01<00:00, 29.23it/s]
100%|██████████| 274/274 [00:17<00:00, 15.72it/s]


In [15]:
def _node_mask(node_ids, size):
    """Return a length-`size` float vector with ones at `node_ids` and zeros elsewhere."""
    mask = np.zeros(size)
    mask[node_ids] = 1
    return mask


num_nodes = G.shape[0]

# Bundle predictions, labels, graph and split masks into the container that
# compute_reeb receives.
nn_model = NN_model()
nn_model.preds = preds_orig
nn_model.labels = y
nn_model.A = G
nn_model.train_mask = _node_mask(train_nodes, num_nodes)
nn_model.val_mask = _node_mask(val_nodes, num_nodes)  # val_nodes is empty, so all zeros
nn_model.test_mask = _node_mask(test_nodes, num_nodes)

# GTDA settings.
smallest_component = 100
overlap = 0.025
# One entry per column of the prediction matrix.
labels_to_eval = list(range(preds_orig.shape[1]))

GTDA_record = compute_reeb(
    GTDA,
    nn_model,
    labels_to_eval,
    smallest_component,
    overlap,
    extra_lens=None,
    node_size_thd=5,
    reeb_component_thd=5,
    nprocs=10,
    device='cuda',
)


Preprocess lens..
Merge reeb nodes...
Build reeb graph...
Total time for building reeb graph is 13.817661762237549 seconds
Compute mixing rate for each sample


In [16]:
from GTDA.GTDA_utils import save_to_json
# Human-readable class names, listed in label-index order (index 0 first).
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
# Map integer label -> name (0 -> "T-shirt/top", ..., 9 -> "Ankle boot").
label_to_name = dict(enumerate(class_names))

# Export the GTDA results, passing "." (the current directory) as the target.
save_to_json(GTDA_record, nn_model, ".", label_to_name)