In [None]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
import os
import csv
# plt.rcParams["figure.figsize"] = (15, 15)


def show(i):
    """Draw ``i`` with ``plt.imshow`` and display the figure right away.

    ``i`` is any image ``plt.imshow`` accepts; callers in this notebook pass
    NumPy arrays and PIL images. Arrays are drawn as-is, so a BGR array from
    ``cv2.imread`` appears with red and blue swapped unless converted first.
    """
    picture = i
    plt.imshow(picture)
    return plt.show()

In [None]:
def generate_small_dataset(path, image_dir="image_small", out_name="train_small.csv"):
    """Write the subset of ``train.csv`` whose images exist on disk.

    Reads ``<path>/train.csv``. It keeps every row for which
    ``<path>/<image_dir>/<image_id>.png`` exists. It writes those rows to
    ``<path>/<out_name>`` with the columns ``image_id``, ``data_provider``,
    ``isup_grade`` and ``gleason_score``, in that order.

    Args:
        path: dataset root directory containing ``train.csv``.
        image_dir: sub-directory of ``path`` holding the ``.png`` images.
        out_name: file name, under ``path``, of the filtered CSV to write.

    Raises:
        FileNotFoundError: if ``train.csv`` does not exist.
        ValueError: raised by ``csv.DictWriter`` if a row has columns other
            than the four listed above.
    """
    image_root = os.path.join(path, image_dir)
    fieldnames = ['image_id', 'data_provider', 'isup_grade', 'gleason_score']

    # The csv module requires newline='' on files it reads/writes. Without it,
    # the writer's "\r\n" terminator is turned into "\r\r\n" on Windows (blank
    # lines between rows), and quoted fields with embedded newlines break.
    with open(os.path.join(path, "train.csv"), 'r', newline='') as f:
        rows = [
            row for row in csv.DictReader(f)
            if os.path.exists(os.path.join(image_root, row['image_id'] + '.png'))
        ]

    with open(os.path.join(path, out_name), 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

# Build ./data/train_small.csv: the rows of ./data/train.csv whose PNG exists under ./data/image_small/.
generate_small_dataset('./data/')

In [None]:
def check_random_weighted_choice(
        path="./data/image_small/0a0f8e20b1222b69416301444b117678.png",
        tile=256, n_pick=25, threshold=240, seed=None):
    """Visually check weighted random tile selection on a single image.

    The function proceeds in these steps:

    - Load ``path`` and show it.
    - Cut it into a grid of ``tile x tile`` tiles, in row-major order.
    - Weight each tile by how many of its channel values are below
      ``threshold``.
    - Draw ``n_pick`` distinct tiles with probability proportional to those
      weights.
    - Print the sorted chosen indices.
    - Show the drawn tiles assembled into a ``sqrt(n_pick) x sqrt(n_pick)``
      mosaic.

    Args:
        path: image file to load with ``cv2.imread``.
        tile: tile edge length in pixels; both image dimensions must be
            multiples of it.
        n_pick: number of tiles to draw; must be a perfect square and no
            larger than the number of tiles.
        threshold: channel values strictly below this count toward a tile's
            weight (presumably dark tissue on a light background; confirm).
        seed: if not None, draws come from ``np.random.default_rng(seed)``
            and are reproducible. If None, the global ``np.random`` state is
            used, as before.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``path``.
        ValueError: if the image is not an exact grid of tiles, ``n_pick`` is
            not a perfect square, or there are fewer than ``n_pick`` tiles.
    """
    img = cv2.imread(path)
    if img is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError("could not read image: {}".format(path))
    # OpenCV returns BGR but matplotlib expects RGB. The weights below sum over
    # all channels, so channel order does not affect which tiles are chosen.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    show(img)

    height, width, channels = img.shape
    if height % tile or width % tile:
        raise ValueError(
            "image shape {} is not a multiple of tile={}".format(img.shape, tile))
    grid_rows, grid_cols = height // tile, width // tile
    # (rows*tile, cols*tile, C) -> (rows*cols, tile, tile, C), tiles in row-major order.
    tiles = (img.reshape(grid_rows, tile, grid_cols, tile, channels)
                .transpose(0, 2, 1, 3, 4)
                .reshape(-1, tile, tile, channels))

    side = int(round(n_pick ** 0.5))
    if side * side != n_pick:
        raise ValueError("n_pick={} must be a perfect square".format(n_pick))
    if n_pick > len(tiles):
        raise ValueError(
            "n_pick={} exceeds the {} available tiles".format(n_pick, len(tiles)))

    # Count of sub-threshold channel values per tile, normalised by tile area,
    # so each weight lies in [0, C]. The epsilon keeps every probability
    # non-zero: choice() with replace=False needs at least n_pick non-zero
    # entries in p.
    weights = (tiles < threshold).sum(axis=(1, 2, 3)) / (tile * tile) + 1e-5
    weights = weights / weights.sum()

    rng = np.random if seed is None else np.random.default_rng(seed)
    idx = rng.choice(len(tiles), n_pick, p=weights, replace=False)
    print("choose index =", np.sort(idx).tolist())

    # Assemble the drawn tiles, in draw order, into a side x side mosaic.
    mosaic = (tiles[idx].reshape(side, side, tile, tile, channels)
                        .transpose(0, 2, 1, 3, 4)
                        .reshape(side * tile, side * tile, channels))
    show(mosaic)

# Visual check of weighted tile sampling on one hard-coded image. The draw is random,
# so the printed indices and the mosaic change between runs unless NumPy's RNG is seeded.
check_random_weighted_choice()

In [None]:
from data.data_builder import build_data
from nets.base import widthN_to_bsN, bsN_to_widthN, Resnet18
import torch
from torchvision.transforms import transforms

def check_data_builder_and_model(data_dir='./data/', n_blocks=25, device=None):
    """Smoke-test ``build_data``, the tile-reshaping helpers and ``Resnet18``.

    For each of the "train" and "valid" splits, the function:

    - Takes one batch from ``build_data`` and shows every image in it.
    - Converts the batch with ``widthN_to_bsN`` and shows the first two tiles
      of the first two samples.
    - Converts back with ``bsN_to_widthN`` and shows the reassembled images.
    - Prints the shape of the model output for the batch.

    Args:
        data_dir: dataset root passed to ``build_data``.
        n_blocks: tiles per image (N). It is passed as ``valid_block`` and used
            by both reshape helpers. NOTE(review): the train split's tile count
            is decided inside ``build_data``; this assumes it equals
            ``n_blocks``. Confirm.
        device: torch device for the model and the batch. Defaults to "cuda"
            when available, otherwise "cpu".
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    to_pil = transforms.ToPILImage()  # stateless, so one instance serves every call
    net = Resnet18().to(device)

    for split in ["train", "valid"]:
        # The second positional argument (2) is presumably the batch size; the
        # tile display below assumes at least two samples per batch. Confirm.
        dataloader = build_data(data_dir, 2, split, num_worker=1, valid_block=n_blocks)
        images, _ = next(iter(dataloader))
        for img in images:
            show(to_pil(img))

        # (bs, 3, 256, 256*N) -> (N*bs, 3, 256, 256)
        x = widthN_to_bsN(images, n_blocks)
        # The first two tiles of sample 0, then the first two tiles of sample 1.
        for tile in torch.cat((x[:n_blocks][:2], x[n_blocks:][:2])):
            show(to_pil(tile))

        # (N*bs, 3, 256, 256) -> (bs, 3, 256, 256*N)
        x = bsN_to_widthN(x, n_blocks)
        for img in x:
            show(to_pil(img))

        # Shape check only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            out = net(images.to(device))
        print(out.shape)


# Smoke test: one batch from each of the "train" and "valid" splits goes through the
# reshape helpers and the model (needs the dataset under ./data/).
check_data_builder_and_model()