In [None]:
import torch
import torch.nn as nn
import torchvision

def widthN_to_bsN(x, N):
    """Move N tiles laid side by side along the width into the batch axis.

    x: (bs, c, h, w*N) -> (N*bs, c, h, w). The N tiles of one sample stay
    adjacent in the output batch (output index = sample * N + tile).
    """
    _, channels, height, total_width = x.shape
    tile_width = total_width // N
    # Cut the strip into width-`tile_width` pieces and stack them on a new
    # axis right after the batch axis: (bs, N, c, h, w).
    pieces = x.split(tile_width, dim=-1)
    stacked = torch.stack(pieces, dim=1)
    # Fold the tile axis into the batch axis.
    return stacked.view(-1, channels, height, tile_width)

def bsN_to_widthN(x, N):
    """Lay every group of N consecutive batch entries side by side along width.

    x: (N*bs, c, h, w) -> (bs, c, h, w*N).
    """
    total, channels, height, _ = x.shape
    groups = total // N
    # Chunks of N consecutive entries, stacked: (bs, N, c, h, w).
    grouped = torch.stack(x.split(N, dim=0))
    # Bring the tile axis next to the width axis: (bs, c, h, N, w) ...
    side_by_side = grouped.permute(0, 2, 3, 1, 4)
    # ... and merge the two into a single width axis of size N*w.
    return side_by_side.reshape(groups, channels, height, -1)

def bsN_to_5dim(x, N):
    """Group every N consecutive batch entries onto a new second axis.

    x: (N*bs, c, h, w) -> (bs, N, c, h, w).
    """
    # The unpacking only serves to insist on a 4-D input, as the original did.
    _batch, _channels, _height, _width = x.shape
    return torch.stack(torch.split(x, N, dim=0))


class Resnet18(nn.Module):
    """ResNet-18 encoder applied per tile, followed by a linear classifier.

    The input is a strip of ``N`` tiles concatenated along the width. Each
    tile is encoded separately; the feature maps are then put back side by
    side, globally average-pooled and mapped to ``classes`` logits.
    """

    def __init__(self, pretrain=True, classes=6, N=25):
        super().__init__()
        print("Using Network = resnet18 classifier", classes)
        self.N = N

        backbone_layers = list(torchvision.models.resnet18(pretrained=pretrain).children())
        # Keep everything except the last two children of the torchvision
        # model (its own pooling layer and fully connected head).
        self.encoder = nn.Sequential(*backbone_layers[:-2])
        # Input width of the discarded fc layer = channels the encoder emits.
        feat_channel = backbone_layers[-1].in_features
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(feat_channel, classes),
        )

    def forward(self, x):
        """Return per-sample class logits for a (bs, c, h, w*N) tile strip."""
        tiles = widthN_to_bsN(x, self.N)
        features = self.encoder(tiles)
        return self.classifier(bsN_to_widthN(features, self.N))


class Resnet18_asigm(nn.Module):
    """ResNet-18 tile encoder with a learned, sigmoid-squashed grade output.

    Two linear heads read the pooled features: ``penultimate`` yields a
    softmax over ``classes`` and ``final`` yields one learned coefficient per
    class. Their product is summed, passed through a sigmoid and scaled to
    the range ``[0, classes - 1]``.
    """

    def __init__(self, pretrain=True, classes=6, N=25):
        super().__init__()
        print("Using Network = resnet18 + learned a (sigmoid)")
        self.N = N
        self.classes = classes

        backbone_layers = list(torchvision.models.resnet18(pretrained=pretrain).children())
        # Keep everything except the last two children of the torchvision
        # model (its own pooling layer and fully connected head).
        self.encoder = nn.Sequential(*backbone_layers[:-2])
        feat_channel = backbone_layers[-1].in_features

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.penultimate = nn.Sequential(
            nn.Linear(feat_channel, classes),
            nn.Softmax(dim=1),
        )
        self.final = nn.Sequential(
            nn.Linear(feat_channel, classes),
        )

    def forward(self, x):
        """Return one scalar score per sample for a (bs, c, h, w*N) tile strip."""
        features = self.encoder(widthN_to_bsN(x, self.N))
        pooled = self.gap(bsN_to_widthN(features, self.N))
        probs = self.penultimate(pooled)
        coeffs = self.final(pooled)
        score = (coeffs * probs).sum(dim=1)
        return (self.classes - 1) * torch.sigmoid(score)


class Resnet18_afix(nn.Module):
    """ResNet-18 tile encoder predicting the expected class index.

    A softmax over ``classes`` is computed from the pooled features; the
    scalar prediction is the softmax-weighted sum of the fixed values
    ``0 .. classes - 1`` held in the buffer ``a``.
    """

    def __init__(self, pretrain=True, classes=6, N=25):
        super().__init__()
        print("Using Network = resnet18 + fixed a")
        self.N = N
        self.classes = classes

        backbone_layers = list(torchvision.models.resnet18(pretrained=pretrain).children())
        # Keep everything except the last two children of the torchvision
        # model (its own pooling layer and fully connected head).
        self.encoder = nn.Sequential(*backbone_layers[:-2])
        feat_channel = backbone_layers[-1].in_features

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.penultimate = nn.Sequential(
            nn.Linear(feat_channel, classes),
            nn.Softmax(dim=1),
        )
        # Fixed (non-trainable) class values; a buffer so it follows the
        # module across devices and is stored in the state dict.
        self.register_buffer('a', torch.arange(classes))

    def forward(self, x):
        """Return ``(pred_label, pred_softmax)`` for a (bs, c, h, w*N) tile strip."""
        features = self.encoder(widthN_to_bsN(x, self.N))
        pooled = self.gap(bsN_to_widthN(features, self.N))
        pred_softmax = self.penultimate(pooled)
        # Expected class index under the predicted distribution.
        pred_label = (pred_softmax * self.a).sum(dim=1)
        return pred_label, pred_softmax


class Resnet50_afix(nn.Module):
    """ResNet-50 tile encoder predicting the expected class index.

    Same head as ``Resnet18_afix``: a softmax over ``classes`` whose
    weighted sum with the fixed values ``0 .. classes - 1`` (buffer ``a``)
    gives the scalar prediction.
    """

    def __init__(self, pretrain=True, classes=6, N=25):
        super().__init__()
        print("Using Network = resnet50 + fixed a")
        self.N = N
        self.classes = classes

        backbone_layers = list(torchvision.models.resnet50(pretrained=pretrain).children())
        # Keep everything except the last two children of the torchvision
        # model (its own pooling layer and fully connected head).
        self.encoder = nn.Sequential(*backbone_layers[:-2])
        feat_channel = backbone_layers[-1].in_features

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.penultimate = nn.Sequential(
            nn.Linear(feat_channel, classes),
            nn.Softmax(dim=1),
        )
        # Fixed (non-trainable) class values; a buffer so it follows the
        # module across devices and is stored in the state dict.
        self.register_buffer('a', torch.arange(classes))

    def forward(self, x):
        """Return ``(pred_label, pred_softmax)`` for a (bs, c, h, w*N) tile strip."""
        features = self.encoder(widthN_to_bsN(x, self.N))
        pooled = self.gap(bsN_to_widthN(features, self.N))
        pred_softmax = self.penultimate(pooled)
        # Expected class index under the predicted distribution.
        pred_label = (pred_softmax * self.a).sum(dim=1)
        return pred_label, pred_softmax


def get_model_by_name(name, **kwargs):
    """Instantiate the network registered under ``name``.

    ``kwargs`` are forwarded unchanged to the model constructor. An unknown
    ``name`` trips the assertion below.
    """
    registry = {
        'res18': Resnet18,
        'res18_asigm': Resnet18_asigm,
        'res18_afix': Resnet18_afix,
        'res50_afix': Resnet50_afix,
    }
    assert name in registry
    return registry[name](**kwargs)

In [None]:
!pip uninstall --yes scikit-image
!pip install ../input/skimage0162/scikit_image-0.16.2-cp37-cp37m-manylinux1_x86_64.whl
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import os
import csv
import random
import cv2
import numpy as np
import time
import albumentations

import skimage.io
from skimage import transform

# Tiling parameters read by tile().
sz = 256  # side length, in pixels, of each square tile
N = 64  # number of tiles kept per image
sN = 8  # tiles per side of the square mosaic; tile() reshapes N tiles to sN x sN
scale_factor = 1  # not referenced in the visible code -- TODO confirm it is still needed


def tile(img):
    """Cut an H x W x 3 image into tiles and assemble the darkest ones.

    The image is padded with 255 to a multiple of ``sz`` and split into
    ``sz`` x ``sz`` tiles. The ``N`` tiles with the smallest pixel sum are
    kept (padding with all-255 tiles if there are fewer than ``N``) and
    arranged, in that order, into an ``sN`` x ``sN`` mosaic.

    Returns:
        (mosaic, av_blk): the ``(sN*sz, sN*sz, 3)`` mosaic and the number of
        kept tiles whose count of channel values below 240, divided by
        ``sz * sz``, exceeds 0.5.
    """
    rows, cols = img.shape[0], img.shape[1]
    pad_rows = (sz - rows % sz) % sz
    pad_cols = (sz - cols % sz) % sz
    # Split the padding as evenly as possible between both sides.
    img = np.pad(
        img,
        [[pad_rows // 2, pad_rows - pad_rows // 2],
         [pad_cols // 2, pad_cols - pad_cols // 2],
         [0, 0]],
        constant_values=255,
    )

    grid_rows, grid_cols = img.shape[0] // sz, img.shape[1] // sz
    tiles = img.reshape(grid_rows, sz, grid_cols, sz, 3)
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, sz, sz, 3)

    shortfall = N - len(tiles)
    if shortfall > 0:
        tiles = np.pad(tiles, [[0, shortfall], [0, 0], [0, 0], [0, 0]], constant_values=255)

    # Smallest pixel sum first, i.e. the darkest tiles are kept.
    darkest_first = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))
    tiles = tiles[darkest_first[:N]]

    av_blk = sum(1 for block in tiles if (block < 240).sum() / sz / sz > 0.5)

    mosaic = tiles.reshape(sN, sN, sz, sz, 3).transpose(0, 2, 1, 3, 4).reshape(sN * sz, sN * sz, 3)
    return mosaic, av_blk


class PANDA_dataset(Dataset):
    """Dataset yielding a strip of tiles per slide for train/valid/test.

    * ``train`` / ``valid``: rows come from ``train.csv``; images are read as
      ``<image_id>.png`` mosaics (8 x 8 tiles of 256 px) from ``IMAGE_DIR``.
      ``valid_blocks`` tiles are sampled per item, weighted by how much
      non-background each tile holds. Items are ``(image, isup_grade)``.
    * any other split (``test``): every file under ``IMAGE_DIR`` is an item;
      the ``.tiff`` is tiled on the fly with ``tile()`` and the first
      ``valid_blocks`` tiles are used. Items are ``(image, image_id)``.

    In both cases the returned image is the inverted (``255 - img``) tile
    strip of shape ``(256, valid_blocks * 256, 3)`` passed through
    ``ToTensor``.
    """

    def __init__(self, data_path="./data", split="train", valid_blocks=25, kfold=0, nfold=5, IMAGE_DIR="images_all"):
        """
        Args:
            data_path: root directory containing ``train.csv`` and ``IMAGE_DIR``.
            split: ``"train"``, ``"valid"`` or anything else for test mode.
            valid_blocks: number of tiles per returned item.
            kfold: index of the fold held out for validation.
            nfold: total number of folds.
            IMAGE_DIR: image directory, relative to ``data_path``.
        """
        super(PANDA_dataset, self).__init__()
        self.split = split
        self.path = data_path
        self.data = []
        self.valid_blocks = valid_blocks
        self.toTensor = transforms.ToTensor()
        self.IMAGE_DIR = IMAGE_DIR
        # Per-tile augmentation, applied only when split == "train".
        self.transform = albumentations.Compose([
            albumentations.Transpose(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.HorizontalFlip(p=0.5),
        ])
        if split == "train" or split == "valid":
            with open(os.path.join(data_path, "train.csv"), 'r') as f:
                csv_f = csv.DictReader(f)
                for row in csv_f:
                    self.data.append(
                        [row['image_id'], row['data_provider'], int(row['isup_grade']), row['gleason_score']])
            self.data_check()
            # Sort so the fold boundaries are reproducible across runs.
            self.data.sort()
            valid_begin = int(len(self.data) / nfold * kfold)
            valid_end = int(len(self.data) / nfold * (kfold + 1))
            if split == "valid":
                self.data = self.data[valid_begin:valid_end]
                print("Using split = {}, data = [{} ... {}], size = {}.".format(
                    split, self.data[0][0], self.data[-1][0], len(self.data)))
            else:
                self.data = self.data[:valid_begin] + self.data[valid_end:]
                print("Using split = {}, data = [{} ...(val)... {}], size = {}.".format(
                    split, self.data[0][0], self.data[-1][0], len(self.data)))
        else:
            print(os.path.join(self.path, self.IMAGE_DIR))
            # Test mode: each file found (recursively) is one item, identified
            # by the part of its name before the first dot.
            for root, dirs, files in os.walk(os.path.join(self.path, self.IMAGE_DIR)):
                for file in files:
                    self.data.append(file.split('.')[0])

    def data_check(self):
        """Report and return the image ids in ``self.data`` whose PNG is missing.

        The entries themselves are left in ``self.data``; this only makes a
        missing file visible up front instead of failing later in
        ``__getitem__`` when the image cannot be read.

        Returns:
            list of image ids with no ``<id>.png`` under ``IMAGE_DIR``.
        """
        image_dir = os.path.join(self.path, self.IMAGE_DIR)
        missing = [
            row[0] for row in self.data
            if not os.path.exists(os.path.join(image_dir, row[0] + ".png"))
        ]
        if missing:
            print("Warning: {} of {} listed images are missing under {} (first: {}).".format(
                len(missing), len(self.data), image_dir, missing[0]))
        return missing

    def __len__(self):
        return len(self.data)

    def __getitem__(self, x):
        """Return ``(image, image_id)`` in test mode, else ``(image, isup_grade)``."""
        if self.split == "test":
            x = self.data[x]
            # Index 1 of the multi-image TIFF is used.
            img = skimage.io.MultiImage(os.path.join(self.path, self.IMAGE_DIR, x + ".tiff"))[1]
            img, av_blk = tile(img)
            # The train/valid branch reads its images with cv2 (BGR); match that.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # Mosaic -> list of 256 x 256 tiles, in the order tile() placed them.
            img = img.reshape(8, 256, 8, 256, 3).transpose(0, 2, 1, 3, 4).reshape(-1, 256, 256, 3)
            img = img[0:self.valid_blocks]
            # Concatenate the tiles into one (256, k*256, 3) strip.
            img = img.reshape(-1, 256, 3).transpose(1, 0, 2)
            img = self.toTensor(255 - img)
            return img, x
        else:
            x = self.data[x]
            img = cv2.imread(os.path.join(self.path, self.IMAGE_DIR, x[0] + ".png"))
            img = img.reshape(8, 256, 8, 256, 3).transpose(0, 2, 1, 3, 4).reshape(-1, 256, 256, 3)
            weights = []
            for i in range(img.shape[0]):
                # Sampling weight grows with the amount of non-background in
                # the tile; the epsilon keeps every tile drawable.
                weights.append((img[i] < 240).sum() / 256 / 256 + 1e-5)  # [0.0 ~ 3.0]
                if self.split == "train":
                    img[i] = self.transform(image=img[i])["image"]
            weights = np.array(weights)
            weights /= weights.sum()
            idx = np.random.choice(np.arange(img.shape[0]), self.valid_blocks, p=weights, replace=False)
            img = img[idx]
            # Concatenate the tiles into one (256, k*256, 3) strip.
            img = img.reshape(-1, 256, 3).transpose(1, 0, 2)
            img = self.toTensor(255 - img)
            return img, x[2]


def build_data(batch_size, num_worker, **kwargs):
    """Create a shuffling DataLoader over a ``PANDA_dataset``.

    ``kwargs`` are forwarded to the ``PANDA_dataset`` constructor.
    """
    dataset = PANDA_dataset(**kwargs)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

In [None]:
import torch.nn as nn
import torchvision
from tqdm import tqdm
import math
import cv2
import torch
import numpy as np
import yaml
import os
import random
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm
import csv

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
RANDOM_SEED = 1337
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

def open_config(root):
    """Load and return the parsed contents of ``<root>/config.yaml``.

    The file is closed as soon as it has been parsed.
    """
    with open(os.path.join(root, "config.yaml")) as f:
        # NOTE(review): FullLoader is acceptable for a trusted experiment
        # config; use yaml.safe_load if the file could ever be untrusted.
        return yaml.load(f, Loader=yaml.FullLoader)


def load(models, epoch, root):
    """Load checkpoint weights into each model in ``models``.

    Args:
        models: mapping ``name -> nn.Module``; weights are read from
            ``<root>/logs/ckpts/<name>_epoch-<epoch>.pth``.
        epoch: epoch to load, or ``-1`` to use the highest epoch found among
            ``<root>/logs/model-epoch-<n>.pth``.
        root: experiment directory.

    Returns:
        The epoch that was loaded, or ``-1`` when ``epoch == -1`` and no
        checkpoint could be detected.
    """
    def _detect_latest():
        prefix, suffix = "model-epoch-", ".pth"
        # Strip exactly the prefix that was matched; slicing off a shorter
        # prefix would leave non-digits in front of the epoch number.
        epochs = [
            int(f[len(prefix):-len(suffix)])
            for f in os.listdir(os.path.join(root, "logs"))
            if f.startswith(prefix) and f.endswith(suffix)
        ]
        return max(epochs) if epochs else None

    if epoch == -1:
        epoch = _detect_latest()
    if epoch is None:
        return -1
    for name, model in models.items():
        ckpt = torch.load(os.path.join(root, "logs/ckpts/" + name + "_epoch-{}.pth".format(epoch)))
        # Drop the first 7 characters of every key.
        # NOTE(review): presumably the "module." prefix that nn.DataParallel
        # adds when saving -- confirm against the training code.
        ckpt = {k[7:]: v for k, v in ckpt.items()}
        model.load_state_dict(ckpt)
        print("load model: {} from epoch: {}".format(name, epoch))
    return epoch


def test(args, root):
    """Run inference on the test images and write ``submission.csv``.

    Args:
        args: parsed config; ``args['data']['valid_blocks']`` sets the number
            of tiles per image for both the dataset and the model.
        root: experiment directory the checkpoint is loaded from.

    The CSV is written in one go after inference has finished, inside a
    ``with`` block, so the file handle is always closed.
    """
    print(args)

    args_data = args['data']
    real_path = "prostate-cancer-grade-assessment/test_images"
    dataloader = build_data(4, num_worker=4, data_path="/kaggle/input/", split="test",
                            valid_blocks=args_data['valid_blocks'], IMAGE_DIR=real_path)
    model = Resnet18_afix(N=args_data['valid_blocks'], pretrain=False).cuda()

    load({"model": model}, 75, root)

    preds = []
    names = []
    model.eval()
    with torch.no_grad():
        for image, name in tqdm(dataloader):
            image = image.cuda()
            # First model output is the scalar prediction; round it to the
            # nearest integer grade.
            fake = model(image)[0].round().int()
            names.extend(name)
            preds.extend(fake.tolist())

    with open("submission.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(("image_id", "isup_grade"))
        writer.writerows(zip(names, preds))
    print(preds)

# Script entry point: read the experiment config from `root` and run test inference.
if __name__ == "__main__":
    # Experiment directory; holds config.yaml and the logs/ckpts checkpoints.
    root="/kaggle/input/panda-exp-xwl/res18_fix"
#     root = "./resnet18_5e5"
    test(open_config(root), root)