# Inference
## Import packages

In [1]:
import os
from functools import wraps
import gc
import random
from pathlib import Path
from datetime import datetime
from typing import List,Tuple

# scientific
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from matplotlib import cm
import matplotlib
import tifffile as tiff
import pandas as pd
from scipy.ndimage import zoom
from sklearn.metrics import r2_score 

# torch
import torch
from torch.utils.data import Dataset
from torch import nn
from torch import Tensor
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split
from torch.optim import (SGD,
                         Adam, )

from torchvision.transforms import (ToTensor,
                                    Compose,
                                    RandomHorizontalFlip,
                                    RandomVerticalFlip,
                                    RandomRotation,
                                    Normalize,
                                    ToPILImage)

## Paths

In [15]:
save_path = Path("/home/lizard/Documents/Code/Project/stoneRegression/result/test") # save path
test_data_path = Path("/home/lizard/Documents/Code/Project/stoneRegression/data/Dataset_Binary/Resolution_3um/Res-03") # image path
test_label_path = Path("/home/lizard/Documents/Code/Project/stoneRegression/data/myLabels.xlsx") # label path
model_path = Path("/home/lizard/Documents/Code/Project/stoneRegression/cpt/20211225-152604/model.pt") # saved model path

## Utility

In [2]:
def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def flush_and_gc(f):
    @wraps(f)
    def g(*args, **kwargs):
        torch.cuda.empty_cache()
        gc.collect()
        return f(*args, **kwargs)

    return g

## Runtime

In [13]:
n_worker = 4
seed = 2021
BATCH_SIZE = 2
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
fix_all_seeds(seed=seed)

## Dataloader

In [4]:
class Stone(Dataset):
    def __init__(self, images_dir: List[Path], label_xlx: Path, transformers):
        self._transformers = transformers
        self._ds_root = images_dir
        self._label_root = label_xlx
        self._lb = np.squeeze(pd.read_excel(str(self._label_root)).to_numpy(), axis=-1)
        lb = [self._lb for _ in range(len(images_dir))]
        self._lb = np.concatenate(lb,axis=0)
        self._f_list = []
        for d_path in self._ds_root:

            f_list = list(d_path.glob("*.tif"))
            f_list.sort(key=lambda p: int(p.stem.split("-")[0]))
            self._f_list+= f_list

        assert len(self._f_list) == len(self._lb)

    def __len__(self):
        return len(self._f_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        n_d_im = tiff.imread(str(self._f_list[idx]))
#         n_d_im = zoom(n_d_im, (0.5, 0.5, 0.5))
        n_d_im = zoom(n_d_im, (0.4, 0.4, 0.4))

        lb = torch.as_tensor(self._lb[idx]*10**15)
        n_d_im = (n_d_im - n_d_im.min())/(n_d_im.max() - n_d_im.min())
        if self._transformers is not None:
            n_d_im = torch.unsqueeze(self._transformers(n_d_im.astype(np.float32)), 0)
            
        return n_d_im, lb

## Transform

In [10]:
def get_simple_transformers(n_channel=300,mean=.5, std=.5):
      return Compose([
            ToTensor(),
            Normalize(mean=[mean] * n_channel, std=[std] * n_channel)
      ])

## Presented models

In [11]:
class Model3DV1(nn.Module):
    def __init__(self, n_channels, n_feature):
        super(Model3DV1, self).__init__()

        self._model = nn.Sequential(
            nn.Conv3d(in_channels=n_channels, out_channels=n_feature, kernel_size=3),
            nn.BatchNorm3d(n_feature),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2),
            
            nn.Conv3d(in_channels=n_feature, out_channels=n_feature * 2, kernel_size=3),
            nn.BatchNorm3d(n_feature*2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2),
            
            nn.Conv3d(in_channels=n_feature * 2, out_channels=n_feature * 4, kernel_size=3),
            nn.BatchNorm3d(n_feature*4),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3),
            
            nn.Flatten(),
            nn.Dropout(0.5),
#             nn.Linear(in_features=170368, out_features=1),
            
            nn.Linear(in_features=65536, out_features=1),
#             nn.ReLU()
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

In [16]:
test_ds = Stone(images_dir = [test_data_path],label_xlx = test_label_path,transformers=get_simple_transformers(120))
test_loader = DataLoader(test_ds,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=n_worker)

model_state_dict = torch.load(str(model_path))
print(model_state_dict)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.