pip install pandas pytorch-lightning torchmetrics scikit-learn scikit-image numpy matplotlib tqdm jupyterlab captum wandb opencv-python-headless timm seaborn plotly umap-learn


In [1]:
import pandas as pd
import re
import numpy as np
from pathlib import Path
from mouse_facial_expressions.paths import *

# Absolute path of the parent of the current working directory; later cells
# read CSVs from `project_dir / 'data/raw/...'`, so this is presumably the
# repository root when the notebook is run from a direct subfolder — TODO confirm.
project_dir = Path('..').resolve()
# Folder holding the extracted video frames. `get_extracted_frames_folder` is not
# defined in this notebook; presumably it is provided by the star import from
# `mouse_facial_expressions.paths` — verify.
frames_dir = Path(get_extracted_frames_folder())

In [2]:

def load_data(dataset_path=None):
    """Build the labelled per-frame table for the CUv2 dataset.

    Reads ``treatments.csv`` and ``video-log.csv`` from ``dataset_path``, lists
    every ``resized-images/<folder>/<frame>.png`` and joins the three sources
    (frames -> videos on ``video``, then -> treatments on ``mouse``).

    Args:
        dataset_path: Folder containing the two CSV files and ``resized-images``.
            Defaults to the original hard-coded CUv2 location so existing
            ``load_data()`` calls keep working.

    Returns:
        pandas.DataFrame with one row per frame, restricted to mice other than
        18, stage ``'4h post injection'`` and treatments ``high``/``saline``,
        plus an integer ``label`` column (1 = high dose, 0 = saline).

    Raises:
        ValueError: If a frame folder name does not contain ``_MGSframes``.
    """
    if dataset_path is None:
        dataset_path = Path('/home/andre/berlin2022/datasets').resolve() / 'CUv2'
    else:
        dataset_path = Path(dataset_path)
    treatments = pd.read_csv(dataset_path / 'treatments.csv')
    videos = pd.read_csv(dataset_path / 'video-log.csv')

    # Get images
    imagepaths = list(dataset_path.glob('resized-images/*/*.png'))

    # Merge datasets
    df = pd.DataFrame({'image': imagepaths})

    def get_video_name_from_imagepath(imagepath):
        # The frame folder is named "<video stem>_MGSframes...".
        folder = imagepath.parts[-2]
        m = re.match("(.*)_MGSframes", folder)
        if m is None:
            raise ValueError(
                f"Frame folder {folder!r} does not match '<video>_MGSframes'")
        return m.group(1) + '.mp4'

    df['video'] = df.image.apply(get_video_name_from_imagepath)
    df = df.merge(videos, on='video')
    df = df.merge(treatments, on='mouse')

    # Filter
    # df = df[df.stage != 'acclimation']
    df = df[df.mouse != 18] # control mouse appeared sick in videos
    df = df[df.stage == '4h post injection']
    # Copy so the label assignments below write to an independent frame
    # rather than to a slice of the merged table.
    df = df[df.treatment.isin(['high', 'saline'])].copy()

    # Label everything a 1
    df['label'] = np.ones(shape=df.shape[0], dtype=int)

    # Label control situations
    # (the 'preinjection' rule only matters if the stage filter above is relaxed)
    df.loc[df.stage == 'preinjection', 'label'] = 0
    df.loc[df.treatment == 'saline', 'label'] = 0

    return df

# Evaluate load_data() as the last expression so the returned frame table is
# rendered as this cell's output.
load_data()

Unnamed: 0,image,video,date,logged_video_start_time,mouse,stage,treatment,pre_experiment_cage,injection_time,injection_date,label
600,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220614_152021...,2022-06-14,15:21:00,5,4h post injection,high,2,11:21:00,2022-06-14,1
601,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220614_152021...,2022-06-14,15:21:00,5,4h post injection,high,2,11:21:00,2022-06-14,1
602,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220614_152021...,2022-06-14,15:21:00,5,4h post injection,high,2,11:21:00,2022-06-14,1
603,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220614_152021...,2022-06-14,15:21:00,5,4h post injection,high,2,11:21:00,2022-06-14,1
604,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220614_152021...,2022-06-14,15:21:00,5,4h post injection,high,2,11:21:00,2022-06-14,1
...,...,...,...,...,...,...,...,...,...,...,...
9488,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220617_154642...,2022-06-17,15:47:00,16,4h post injection,saline,8,11:47:00,2022-06-17,0
9489,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220617_154642...,2022-06-17,15:47:00,16,4h post injection,saline,8,11:47:00,2022-06-17,0
9490,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220617_154642...,2022-06-17,15:47:00,16,4h post injection,saline,8,11:47:00,2022-06-17,0
9491,/home/andre/berlin2022/datasets/CUv2/resized-i...,Basler_acA1920-40um__23999063__20220617_154642...,2022-06-17,15:47:00,16,4h post injection,saline,8,11:47:00,2022-06-17,0


In [3]:
# Build the per-frame table for the 20230627 export: one row per extracted
# frame, joined with its mouse's treatment and the matching raw-video log row.
treatments_df = pd.read_csv(project_dir / 'data/raw/treatments_20230627.csv')

frames_folder_df = pd.DataFrame(dict(image=list(frames_dir.glob('*/*.png'))))
# The folder directly above each frame is used as the video name.
frames_folder_df['video'] = frames_folder_df.image.apply(lambda x: x.parts[-2])
# Raw strings keep `\d` a regex escape; in a plain literal it is an invalid
# string escape that newer Python versions warn about.
frames_folder_df['mouse'] = frames_folder_df.video.apply(lambda x: re.match(r'([mf]\d+)', x).group(1))
frames_folder_df['recording'] = frames_folder_df.video.apply(lambda x: int(re.match(r'.*rec(\d+)', x).group(1)))

raw_videos_df = pd.read_csv(project_dir / 'data/raw/raw_videos_20230627.csv')
# Missing recording numbers become -1 so the column can be cast to int.
raw_videos_df.recording = raw_videos_df.recording.fillna(-1).astype(int)
raw_videos_df['video_time'] = raw_videos_df.apply(lambda x: f"{x.hour:02}:{x.minutes:02}", axis=1)
raw_videos_df['mouse'] = raw_videos_df.animal

combined_df = treatments_df.merge(frames_folder_df, how='left', on='mouse')
combined_df = combined_df.merge(raw_videos_df, how='left', on=['mouse', 'recording'])

combined_df = combined_df[combined_df.mouse != 'm18']
combined_df = combined_df[combined_df.treatment.isin(['high', 'saline'])]
# Copy so the label assignments below write to an independent frame rather
# than to a slice of the merged table.
combined_df = combined_df[combined_df.recording.isin([1, 4])].copy()

# Label everything a 1
combined_df['label'] = np.ones(shape=combined_df.shape[0], dtype=int)

# Label control situations
combined_df.loc[combined_df.recording == 1, 'label'] = 0
combined_df.loc[combined_df.treatment == 'saline', 'label'] = 0

combined_df.sample(20)

Unnamed: 0,mouse,date_of_birth,treatment,injection_time,notes,image,video,recording,camera,year,...,hour,minutes,seconds,animal,start,end,discard,Notes,video_time,label
11703,m12,16 April 2022,saline,12:00,approx timing,/backup/data/extracted_frames/20230627/m12_rec...,m12_rec1_preinjection,1,Basler_acA1920-40um,2022,...,11,39,9,m12,0:20,-1.0,,,11:39,0
3899,m4,11 January 2022,high,13:07,,/backup/data/extracted_frames/20230627/m4_rec4...,m4_rec4_4h-postinjection,4,Basler_acA1920-40um,2022,...,17,22,35,m4,0,-1.0,,re-recorded,17:22,1
33099,f1,13 November 22,saline,12:23,,/backup/data/extracted_frames/20230627/f1_rec4...,f1_rec4_4h-postinjection,4,Basler_acA1920-40um,2023,...,16,20,23,f1,,,,,16:20,0
35429,f3,13 November 22,high,12:53,,/backup/data/extracted_frames/20230627/f3_rec1...,f3_rec1_preinjection,1,Basler_acA1920-40um,2023,...,12,37,54,f3,,,,,12:37,0
6238,m7,25 March 2022,saline,11:50,,/backup/data/extracted_frames/20230627/m7_rec4...,m7_rec4_4h-postinjection,4,Basler_acA1920-40um,2022,...,15,48,29,m7,0:07,-1.0,,,15:48,0
4456,m5,11 January 2022,high,11:21,,/backup/data/extracted_frames/20230627/m5_rec4...,m5_rec4_4h-postinjection,4,Basler_acA1920-40um,2022,...,15,20,21,m5,0:07,-1.0,,,15:20,1
43538,f10,15 August 22,saline,12:08,,/backup/data/extracted_frames/20230627/f10_rec...,f10_rec4_4h-postinjection,4,Basler_acA1920-40um,2023,...,16,59,53,f10,,,,,16:59,0
33100,f1,13 November 22,saline,12:23,,/backup/data/extracted_frames/20230627/f1_rec4...,f1_rec4_4h-postinjection,4,Basler_acA1920-40um,2023,...,16,20,23,f1,,,,,16:20,0
46789,f12,25 October 22,high,12:38,,/backup/data/extracted_frames/20230627/f12_rec...,f12_rec1_preinjection,1,Basler_acA1920-40um,2023,...,12,21,10,f12,,,,,12:21,0
40740,f8,25 October 22,high,13:00,,/backup/data/extracted_frames/20230627/f8_rec4...,f8_rec4_4h-postinjection,4,Basler_acA1920-40um,2023,...,17,46,1,f8,,,,,17:46,1


In [4]:
import argparse
import os
import random
import re
from pathlib import Path
from sklearn.model_selection import GroupKFold
from pytorch_lightning.callbacks import LearningRateMonitor
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import sklearn
import torch
import torchvision
import wandb
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from skimage.io import imread
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchmetrics import Accuracy, ConfusionMatrix
from torchmetrics.functional import accuracy
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.transforms import (Compose, Normalize, RandAugment,
                                    ToPILImage, ToTensor, TrivialAugmentWide)

class FP(torch.nn.Module):
    """Count false positives: samples with target 0 whose argmax prediction is 1.

    Implemented as ``forward`` (not ``__call__``) so calling the module goes
    through the regular ``torch.nn.Module`` call path.
    """

    def forward(self, preds, target):
        """Return the false-positive count as a 0-dim int32 tensor.

        Args:
            preds: Per-class scores; the predicted class is the argmax over dim 1.
            target: Class index per sample.
        """
        predicted = torch.argmax(preds, 1)
        hits = torch.logical_and((target == 0), (predicted == 1))
        return torch.sum(hits).int()

class TP(torch.nn.Module):
    """Count true positives: samples with target 1 whose argmax prediction is 1.

    Implemented as ``forward`` (not ``__call__``) so calling the module goes
    through the regular ``torch.nn.Module`` call path.
    """

    def forward(self, preds, target):
        """Return the true-positive count as a 0-dim int32 tensor.

        Args:
            preds: Per-class scores; the predicted class is the argmax over dim 1.
            target: Class index per sample.
        """
        predicted = torch.argmax(preds, 1)
        hits = torch.logical_and((target == 1), (predicted == 1))
        return torch.sum(hits).int()
    
class FN(torch.nn.Module):
    """Count false negatives: samples with target 1 whose argmax prediction is 0.

    Implemented as ``forward`` (not ``__call__``) so calling the module goes
    through the regular ``torch.nn.Module`` call path.
    """

    def forward(self, preds, target):
        """Return the false-negative count as a 0-dim int32 tensor.

        Args:
            preds: Per-class scores; the predicted class is the argmax over dim 1.
            target: Class index per sample.
        """
        predicted = torch.argmax(preds, 1)
        misses = torch.logical_and((target == 1), (predicted == 0))
        return torch.sum(misses).int()
    
class TN(torch.nn.Module):
    """Count true negatives: samples with target 0 whose argmax prediction is 0.

    Implemented as ``forward`` (not ``__call__``) so calling the module goes
    through the regular ``torch.nn.Module`` call path.
    """

    def forward(self, preds, target):
        """Return the true-negative count as a 0-dim int32 tensor.

        Args:
            preds: Per-class scores; the predicted class is the argmax over dim 1.
            target: Class index per sample.
        """
        predicted = torch.argmax(preds, 1)
        rejections = torch.logical_and((target == 0), (predicted == 0))
        return torch.sum(rejections).int()
    
class DeepSet(pl.LightningModule):
    """Set-level two-class classifier built on an ImageNet-pretrained ResNet-34.

    Every frame of a set is passed through the backbone, whose head is replaced
    by a single-output linear layer. The per-frame outputs are averaged over the
    set and a final ``Linear(1, 2)`` maps that average to two class logits.

    Args:
        config: Mapping with ``learning_rate``, ``warmup_steps``,
            ``warmup_decay``, ``total_steps`` and ``label_smoothing``; an
            optional ``weight_decay`` sets the SGD weight decay.
        class_weights: Per-class weight tensor passed to ``CrossEntropyLoss``.
    """

    def __init__(self, config, class_weights):
        super().__init__()
        self.config = config
        self.model_ = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
        # for parameter in self.model_.parameters():
        #     parameter.requires_grad = False

        in_features = self.model_.fc.in_features
        n_classes = 2
        features = 1
        # Replace the ImageNet head: the backbone now emits `features` values per frame.
        self.model_.fc = torch.nn.Linear(in_features, features)
        self.fc = torch.nn.Linear(features, n_classes) # The purpose of this layer is really just to add bias
        self.criterion = torch.nn.CrossEntropyLoss(
            weight=class_weights, label_smoothing=config['label_smoothing'])

        self.metrics = torch.nn.ModuleDict({
            'accuracy': Accuracy(task='multiclass', num_classes=2),
            'tn': TN(),
            'fn': FN(),
            'tp': TP(),
            'fp': FP()
        })

    def configure_optimizers(self):
        """Return SGD with a linear warm-up followed by cosine annealing, stepped per batch."""
        lr = self.config['learning_rate']
        warmup_steps = self.config['warmup_steps']
        warmup_decay = self.config['warmup_decay']
        total_steps = self.config['total_steps']
        # Weight decay was previously taken from `warmup_decay` (the warm-up start
        # factor). Use the dedicated key, falling back to the old value for
        # configs that do not define it.
        weight_decay = self.config.get('weight_decay', warmup_decay)

        # Prefer the trainer's step budget; fall back to the configured total
        # when the trainer was created without a positive `max_steps`.
        trainer_steps = self.trainer.max_steps
        if trainer_steps is None or trainer_steps <= 0:
            trainer_steps = total_steps
        # Keep the cosine period positive even if warm-up covers the whole run.
        cosine_steps = max(1, trainer_steps - warmup_steps)

        optimizer = torch.optim.SGD(params=self.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[
                torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_decay, total_iters=warmup_steps),
                torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_steps),
            ],
            milestones=[warmup_steps]
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def shared_step_(self, stage, batch, batch_idx):
        """Run one forward pass, log ``<stage>_loss`` and ``<stage>_<metric>``, and return the loss.

        Args:
            stage: Prefix for the logged names (``"train"``, ``"val"`` or ``"test"``).
            batch: Mapping with ``'image'`` and ``'label'`` entries; the first two
                dimensions of ``'image'`` are treated as (batch, set).
            batch_idx: Unused; kept for the Lightning step signature.
        """
        x = batch['image'].float()
        y = batch['label'].long()

        # Flatten (batch, set, ...) -> (batch x set, ...)
        s = x.shape
        x = x.flatten(0, 1)

        # Pass through model
        z = self.model_(x)

        # Reshape (batch x set, features) -> (batch, set, features)
        z = z.reshape(*s[:2], z.shape[-1])
        z = z.mean(dim=1) # Average the per-frame outputs over the set dimension

        # Finally add bias
        y_hat = self.fc(z)
        loss = self.criterion(y_hat, y)
        self.log(stage+"_loss", loss.item(), prog_bar=True)

        with torch.no_grad():
            # tn/fn/tp/fp are counts for this batch only.
            # NOTE(review): the logged epoch value is presumably their mean over batches — confirm.
            for k in self.metrics:
                self.log(stage+"_"+k, self.metrics[k](y_hat, y), prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Log the current learning rate, then run the shared step with the ``train`` prefix."""
        cur_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log("lr", cur_lr, prog_bar=True, on_step=True)
        return self.shared_step_("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Run the shared step with the ``val`` prefix."""
        return self.shared_step_("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Run the shared step with the ``test`` prefix."""
        return self.shared_step_("test", batch, batch_idx)

In [5]:
import argparse
import os
import random
import re
from pathlib import Path
from sklearn.model_selection import StratifiedGroupKFold
from pytorch_lightning.callbacks import LearningRateMonitor
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import sklearn
import torch
import torchvision
import wandb
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from skimage.io import imread
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchmetrics import Accuracy, ConfusionMatrix
from torchmetrics.functional import accuracy
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.transforms import (Compose, Normalize, RandAugment,
                                    ToPILImage, ToTensor, TrivialAugmentWide)

# parser = argparse.ArgumentParser(description='Train a facial expression model')
# parser.add_argument("--param_version", default='0', type=str)
# parser.add_argument("--learning_rate", default=0.01, type=float)
# parser.add_argument("--total_steps", default=10000, type=int)
# parser.add_argument("--warmup_steps", default=1000, type=int)
# parser.add_argument("--cross_validation_fold", default=0, type=int)
# parser.add_argument("--shuffle", default=0, type=int)
# parser.add_argument("--cross_validation_folds", default=9, type=int)
# parser.add_argument("--weight_decay", default=1e-4, type=float)
# parser.add_argument("--warmup_decay", default=0.0001, type=float)
# parser.add_argument("--frames_per_set", default=5, type=int)
# parser.add_argument("--batch_size", default=10, type=int)
# parser.add_argument("--label_smoothing", default=0.1, type=float)
# parser.add_argument("--model", default="deepset", choices=["deepset", "resnet"])
# args = parser.parse_args()

def load_data(project_root=None, frames_root=None):
    """Build the labelled per-frame table for the 20230627 export.

    Lists every ``<video>/<frame>.png`` under the frames folder, derives the
    mouse id and recording number from the video folder name, and joins the
    frames with ``treatments_20230627.csv`` (on ``mouse``) and
    ``raw_videos_20230627.csv`` (on ``mouse`` and ``recording``).

    Args:
        project_root: Folder containing ``data/raw``. Defaults to the
            notebook-level ``project_dir``.
        frames_root: Folder containing one sub-folder of PNG frames per video.
            Defaults to the notebook-level ``frames_dir``.

    Returns:
        pandas.DataFrame with one row per frame, restricted to mice other than
        ``m18``, treatments ``high``/``saline`` and recordings 1 and 4, plus an
        integer ``label`` column: 1 for recording 4 of a ``high`` mouse, 0 for
        recording 1 and for all ``saline`` rows.
    """
    # Fall back to the notebook globals so existing `load_data()` calls behave as before.
    project_root = project_dir if project_root is None else Path(project_root)
    frames_root = frames_dir if frames_root is None else Path(frames_root)

    treatments_df = pd.read_csv(project_root / 'data/raw/treatments_20230627.csv')

    frames_folder_df = pd.DataFrame(dict(image=list(frames_root.glob('*/*.png'))))
    # The folder directly above each frame is used as the video name.
    frames_folder_df['video'] = frames_folder_df.image.apply(lambda x: x.parts[-2])
    # Raw strings keep `\d` a regex escape; in a plain literal it is an invalid
    # string escape that newer Python versions warn about.
    frames_folder_df['mouse'] = frames_folder_df.video.apply(lambda x: re.match(r'([mf]\d+)', x).group(1))
    frames_folder_df['recording'] = frames_folder_df.video.apply(lambda x: int(re.match(r'.*rec(\d+)', x).group(1)))

    raw_videos_df = pd.read_csv(project_root / 'data/raw/raw_videos_20230627.csv')
    # Missing recording numbers become -1 so the column can be cast to int.
    raw_videos_df.recording = raw_videos_df.recording.fillna(-1).astype(int)
    raw_videos_df['video_time'] = raw_videos_df.apply(lambda x: f"{x.hour:02}:{x.minutes:02}", axis=1)
    raw_videos_df['mouse'] = raw_videos_df.animal

    combined_df = treatments_df.merge(frames_folder_df, how='left', on='mouse')
    combined_df = combined_df.merge(raw_videos_df, how='left', on=['mouse', 'recording'])

    combined_df = combined_df[combined_df.mouse != 'm18']
    combined_df = combined_df[combined_df.treatment.isin(['high', 'saline'])]
    # Copy so the label assignments below write to an independent frame rather
    # than to a slice of the merged table.
    combined_df = combined_df[combined_df.recording.isin([1, 4])].copy()

    # Label everything a 1
    combined_df['label'] = np.ones(shape=combined_df.shape[0], dtype=int)

    # Label control situations
    combined_df.loc[combined_df.recording == 1, 'label'] = 0
    combined_df.loc[combined_df.treatment == 'saline', 'label'] = 0

    return combined_df

class MyIterDataset(IterableDataset):
    """Endless stream of frame sets, each drawn from one randomly chosen video.

    Args:
        df: Table with one row per frame and at least the columns ``video``,
            ``label`` and ``image``.
        image_transform: Callable applied to every image returned by ``imread``;
            its outputs are combined with ``torch.stack``.
        frames_per_sample: Number of frames drawn for each yielded sample.

    Attributes set in ``__init__``:
        videos: List of the distinct video names, in groupby order.
        video_labels: Series indexed by video holding the first ``label`` value
            of each video.
    """

    def __init__(self, df, image_transform, frames_per_sample):
        super().__init__()
        self.df = df
        self.video_groups = df.groupby('video')
        # One label per video: the first label found among that video's rows.
        self.video_labels = self.video_groups['label'].first()
        self.videos = self.video_labels.index.tolist()
        self.frames_per_sample = frames_per_sample
        self.image_transform = image_transform

    def random_video(self):
        """Return the name of one video, chosen uniformly at random."""
        index = random.randint(0, len(self.videos)-1)
        return self.videos[index]

    def random_frames(self, video):
        """Return ``frames_per_sample`` randomly chosen rows of ``video``.

        Rows are drawn without replacement when the video has enough frames.
        A video with fewer frames than ``frames_per_sample`` is sampled with
        replacement instead of raising, so every sample has the same size.
        """
        frames = self.video_groups.get_group(video)
        needs_replacement = len(frames) < self.frames_per_sample
        return frames.sample(self.frames_per_sample, replace=needs_replacement)

    def get_image(self, imagepath):
        """Read the image at ``imagepath`` and apply ``image_transform`` to it."""
        return self.image_transform(imread(imagepath))

    def __iter__(self):
        """Yield ``{'label', 'image'}`` dicts forever; the stream never ends.

        NOTE(review): sampling relies on the global ``random`` module and pandas'
        default RNG — confirm that DataLoader workers are seeded differently.
        """
        while True:
            video = self.random_video()
            random_frames = self.random_frames(video)
            # All rows come from one video; the first row's label is used for the set.
            label = random_frames.label.iloc[0]
            out = {
                'label': label,
                'image': torch.stack(random_frames.image.apply(self.get_image).tolist())
            }
            yield out
    
class TestableDataset(Dataset):
    """Expose the first ``max_iterations`` items of an iterable as a sized dataset.

    Items are pulled sequentially from ``iter(iterable_dataloader)``; the
    ``index`` passed to ``__getitem__`` is only used for the bounds check.

    Args:
        iterable_dataloader: Any iterable (here a DataLoader over an endless
            IterableDataset) that yields the items to serve.
        max_iterations: Value reported by ``len()`` and the number of items after
            which indexing raises ``IndexError``.
    """

    def __init__(self, iterable_dataloader, max_iterations=1000):
        super().__init__()
        self.iterable_dataloader = iterable_dataloader
        self.max_iterations = max_iterations
        self.init_iter()

    def init_iter(self):
        """(Re)start iteration over the wrapped iterable."""
        self.iter_dataloader = iter(self.iterable_dataloader)

    def __len__(self):
        return self.max_iterations

    def __getitem__(self, index):
        """Return the next item of the wrapped iterable.

        Raises:
            IndexError: If ``index >= max_iterations``, so that plain iteration
                over the dataset stops at ``len(self)`` even when the wrapped
                iterable is endless, or if the wrapped iterable is exhausted.
        """
        if index >= self.max_iterations:
            raise IndexError(
                f"index {index} is out of range for {self.max_iterations} iterations")
        try:
            return next(self.iter_dataloader)
        except StopIteration:
            # Report exhaustion with the exception type indexing is expected to raise.
            raise IndexError("wrapped iterable is exhausted") from None

In [16]:
# Preview a 5-fold, mouse-grouped, label-stratified split of the frame table.
df = load_data()

# display(df.groupby('video').count()['image'])
cv = StratifiedGroupKFold(5)
splits = list(cv.split(df.index, df.label, df.mouse))

for split_index, split in enumerate(splits):
    print('Split Index')
    # train, test = splits[config['cross_validation_fold']]
    train, test = split
    # The splitter returns positions; map them back to index labels before selecting rows.
    train_df, test_df = (df.loc[df.index[rows]] for rows in split)

    # Mice on the training side, then mice on the held-out side.
    print(train_df.mouse.unique(), test_df.mouse.unique())

Split Index
['m3' 'm4' 'm5' 'm7' 'm9' 'm15' 'm16' 'f3' 'f7' 'f8' 'f10' 'f12' 'f16'
 'f31' 'f32'] ['m12' 'm19' 'f1' 'f15']
Split Index
['m3' 'm7' 'm9' 'm12' 'm15' 'm19' 'f1' 'f3' 'f7' 'f8' 'f10' 'f12' 'f15'
 'f31' 'f32'] ['m4' 'm5' 'm16' 'f16']
Split Index
['m3' 'm4' 'm5' 'm7' 'm9' 'm12' 'm15' 'm16' 'm19' 'f1' 'f8' 'f10' 'f12'
 'f15' 'f16' 'f31' 'f32'] ['f3' 'f7']
Split Index
['m4' 'm5' 'm7' 'm12' 'm15' 'm16' 'm19' 'f1' 'f3' 'f7' 'f10' 'f15' 'f16'
 'f31'] ['m3' 'm9' 'f8' 'f12' 'f32']
Split Index
['m3' 'm4' 'm5' 'm9' 'm12' 'm16' 'm19' 'f1' 'f3' 'f7' 'f8' 'f12' 'f15'
 'f16' 'f32'] ['m7' 'm15' 'f10' 'f31']


In [None]:
from IPython.display import display

# Hyper-parameters for this run; DeepSet reads several of these keys directly.
config = dict(
    param_version=0,
    learning_rate=0.01,
    total_steps=5000,
    warmup_steps=200,
    cross_validation_fold=0,
    shuffle=0,
    cross_validation_folds=9,
    weight_decay=1e-4,
    warmup_decay=0.0001,
    frames_per_set=5,
    batch_size=10,
    label_smoothing=0.1,
    model='deepset'
)

# The transform pipelines and the model lookup do not depend on the fold, so
# they are built once here instead of on every loop iteration.
train_augmentation = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.TrivialAugmentWide(),
    torchvision.transforms.ToTensor()
])

test_augmentation = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

model_classes = dict(deepset=DeepSet)
model_class = model_classes[config['model']]

df = load_data()
cv = StratifiedGroupKFold(config['cross_validation_folds'])
splits = list(cv.split(df.index, groups=df.mouse, y=df.label))

# Train and test one fresh model per fold.
for split_index, split in enumerate(splits):
    print('Split Index')
# train, test = splits[config['cross_validation_fold']]
    train, test = split
    train_df = df.loc[df.index[train]]
    test_df = df.loc[df.index[test]]

    # Frame counts per label on each side of the fold.
    display(train_df.groupby('label').count()['image'])
    display(test_df.groupby('label').count()['image'])

    train_dataset = MyIterDataset(
        train_df, frames_per_sample=config['frames_per_set'], image_transform=train_augmentation)
    test_dataset = MyIterDataset(
        test_df, frames_per_sample=config['frames_per_set'], image_transform=test_augmentation)

    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], num_workers=6)
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], num_workers=6)

    # Class weights come from per-video labels, not per-frame labels.
    weights = sklearn.utils.class_weight.compute_class_weight(
        'balanced',
        classes=np.unique(train_dataset.video_labels),
        y=train_dataset.video_labels
    )

    loggers = [
        # pl.loggers.WandbLogger()
    ]
    # for logger in loggers:
    #     logger.log_hyperparams(config)

    callbacks = [
        # LearningRateMonitor()
    ]
    trainer = pl.Trainer(
        max_steps=config['total_steps'],
        accelerator='gpu', devices=[0],
        val_check_interval=1000, limit_val_batches=100,
        logger = loggers,
        enable_checkpointing=False,
        callbacks=callbacks
    )

    model = model_class(config, class_weights=torch.from_numpy(weights).float())
    # The held-out fold's loader doubles as the validation loader during fitting.
    trainer.fit(model, train_dataloader, test_dataloader)

    # Wrap the endless test loader so the test run has a finite length.
    actually_testable_dataset = TestableDataset(test_dataloader)
    trainer.test(model, actually_testable_dataset)

Split Index


label
0    5000
1    2400
Name: image, dtype: int64

label
0    800
1    200
Name: image, dtype: int64

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3080 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model_    | ResNet           | 21.3 M
1 | fc        | Linear           | 4     
2 | criterion | CrossEntropyLoss | 0     
3 | metrics   | ModuleDict       | 0     
-----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.141    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=5000` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.9968000054359436
         test_fn                    0.0
         test_fp            0.03200000151991844
        test_loss           0.29229822754859924
         test_tn            7.4679999351501465
         test_tp                    2.5
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Split Index


label
0    5200
1    2200
Name: image, dtype: int64

label
0    600
1    400
Name: image, dtype: int64

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model_    | ResNet           | 21.3 M
1 | fc        | Linear           | 4     
2 | criterion | CrossEntropyLoss | 0     
3 | metrics   | ModuleDict       | 0     
-----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.141    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=5000` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.9998999834060669
         test_fn           0.0010000000474974513
         test_fp                    0.0
        test_loss           0.2829274535179138
         test_tn             7.484000205993652
         test_tp             2.515000104904175
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Split Index


label
0    5200
1    2000
Name: image, dtype: int64

label
0    600
1    600
Name: image, dtype: int64

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model_    | ResNet           | 21.3 M
1 | fc        | Linear           | 4     
2 | criterion | CrossEntropyLoss | 0     
3 | metrics   | ModuleDict       | 0     
-----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.141    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=5000` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.9936000108718872
         test_fn                    0.0
         test_fp            0.06400000303983688
        test_loss           0.2867034673690796
         test_tn             7.456999778747559
         test_tp            2.4790000915527344
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Split Index


label
0    5000
1    2200
Name: image, dtype: int64

label
0    800
1    400
Name: image, dtype: int64

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model_    | ResNet           | 21.3 M
1 | fc        | Linear           | 4     
2 | criterion | CrossEntropyLoss | 0     
3 | metrics   | ModuleDict       | 0     
-----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.141    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]