In [1]:
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import albumentations as A
import albumentations.pytorch
import cv2
import sys
import random
import csv
import matplotlib.pyplot as plt
sys.path.append('../')

from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data.sampler import Sampler
import torch.optim as optim

from dataset import LbpDataset, train_transforms, val_transforms, test_transforms, get_indices
from model import LBPModel, CNNModel
from loss import LBPloss
from visualize import visualize

from tqdm import tqdm
from resnet import resnet18, resnet12



In [2]:
# m = nn.AdaptiveAvgPool2d((1,1))
# input = torch.randn(1, 64, 8, 9)
# output = m(input)
# output.size()

In [3]:
# model = resnet12(pretrained=False)
# model.eval()

In [4]:
df = pd.read_csv('../data/df.csv')
print(df.shape)

# Fail early if the columns used for grouping / box extraction are missing.
missing_cols = {'path', 'xmin', 'ymin', 'w', 'h'} - set(df.columns)
assert not missing_cols, f'df.csv is missing columns: {sorted(missing_cols)}'

# Single-class setup: every box gets class id 0.0 (scalar broadcast, no per-row apply).
df['label_id'] = 0.0
# Boxes grouped per patch path; get_data() looks patches up in this.
df_data = df.groupby('path')
def get_data(img_id, grouped=None):
    """Build the annotation record for one image patch.

    Args:
        img_id: patch path, matching the ``path`` column of the annotation table.
        grouped: a ``DataFrame.groupby('path')`` result to look ``img_id`` up in.
            Defaults to the notebook-level ``df_data``.

    Returns:
        ``dict(image_id=img_id, boxes=...)``. ``boxes`` is an ``(n, 5)`` array of
        ``[xmin, ymin, w, h, label_id]`` rows, or an empty list when the patch has
        no annotations.
    """
    if grouped is None:
        grouped = df_data
    if img_id not in grouped.groups:
        # Keep the empty-list sentinel: downstream dataset code may rely on it.
        return dict(image_id=img_id, boxes=list())
    boxes = grouped.get_group(img_id)[['xmin', 'ymin', 'w', 'h', 'label_id']].values
    return dict(image_id=img_id, boxes=boxes)

# One annotation record per distinct patch path, in first-seen order.
train_list = list(map(get_data, df.path.unique()))
print(len(train_list))
train_list[0]

(1555, 12)
1214


{'image_id': 'patch_images/2021.01.08/LBC141-20210105(1)/LBC141-20210105(1)_1001.png',
 'boxes': array([[1558., 1603.,   96.,   73.,    0.],
        [1452., 1263.,   82.,   94.,    0.]])}

In [5]:
# Patch ids in df.csv look like 'patch_images/...', presumably relative to this root.
# Keep the root and the slide's relative directory in one place each.
DATA_ROOT = '/home/Dataset/scl/'
TEST_SLIDE_DIR = 'patch_images/2021.01.06/LBC24-20210102(1)/'
path = os.path.join(DATA_ROOT, TEST_SLIDE_DIR)

# os.listdir order is arbitrary; sort so test_list is reproducible.
file_list = [TEST_SLIDE_DIR + name for name in sorted(os.listdir(path))]
test_list = [get_data(img_id) for img_id in file_list]
test_list[:2]

[{'image_id': 'patch_images/2021.01.06/LBC24-20210102(1)/LBC24-20210102(1)_1160.png',
  'boxes': []},
 {'image_id': 'patch_images/2021.01.06/LBC24-20210102(1)/LBC24-20210102(1)_1817.png',
  'boxes': []}]

In [6]:
BATCH_SIZE = 8


def _make_loader(records, transform, batch_size, workers):
    """Wrap annotation records in an LbpDataset plus a shuffling DataLoader."""
    dataset = LbpDataset(records, transform=transform)
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=workers,
        shuffle=True,
        drop_last=False,
    )
    return dataset, loader


# Training data uses the augmenting transforms; the test slide uses validation transforms.
train_dataset, train_loader = _make_loader(train_list, train_transforms, BATCH_SIZE, 8)
test_dataset, test_loader = _make_loader(test_list, val_transforms, 2, 2)

In [7]:
# Sanity-check shapes: first one raw sample, then one collated batch.
for source in (train_dataset, train_loader):
    image, cell_iou, targets, path = next(iter(source))
    for tensor in (image, cell_iou, targets):
        print(tensor.shape)

(2048, 2048, 3)
torch.Size([3721, 1])
torch.Size([3721, 1])
torch.Size([8, 2048, 2048, 3])
torch.Size([8, 3721, 1])
torch.Size([8, 3721, 1])


In [8]:
# a = np.array([1,5,3,2])
# b = np.array([2,1,4,4])
# c = np.array([6,3,1,2])
# t_list = []
# t_list.append(a)
# t_list.append(b)
# t_list.append(c)
# np.max(t_list, axis=0)

In [9]:
# Seed the Python, NumPy and torch RNGs before model initialisation so weight
# init and the default-generator shuffling are repeatable between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU when no GPU is visible instead of failing on the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet12(pretrained=False).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
loss_fn = LBPloss(device).to(device)

In [10]:
# # torch.randperm(0)
# a = torch.tensor([1,10,3,5,7,11,32,22])
# index = [3,5,7]
# # for b in a[index]
# a[index]

In [11]:
# a = torch.tensor([  13,  202,  388,  426,  687, 1224, 1357, 1403, 1465, 1723, 1737, 1822,
#         2121, 3084, 3190, 3207, 3357, 3396, 3518, 3640])
# list(a.int())
# a.tolist()

# torch.Size([2, 1, 128, 128, 3])
# a = torch.randn(2,1,8,8,3)
# b = torch.randn(2,1,8,8,3)
# b1 = torch.randn(2,1,8,8,3)
# c = [a, b, b1]
# index = [1,2]
# c[index]

In [None]:
EPOCHS = 10

for epoch in range(EPOCHS):
    # Ensure dropout/batch-norm are in training mode even if an earlier
    # inference cell switched the model to eval().
    model.train()
    batch_losses = []
    loss_sum = 0.0
    loop = tqdm(train_loader, leave=True, desc=f'epoch {epoch + 1}/{EPOCHS}')
    for images, iou, targets, path in loop:
        # Loader yields channels-last batches (B, H, W, C); the model takes (B, C, H, W).
        images = images.permute(0, 3, 1, 2).to(device)

        # get_indices (project helper) presumably picks the grid cells used for this
        # step and returns their iou/targets, which are stacked into one label tensor.
        indices, iou, targets = get_indices(iou, targets)
        labels = torch.cat([iou, targets], dim=-1).to(device)

        outputs = model(images, indices)
        # cell_loss is returned by LBPloss but not used here.
        # NOTE(review): the saved run logs a negative, decreasing mean loss — check LBPloss.
        loss, cell_loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        # Running mean without re-summing the whole list on every step.
        loss_sum += batch_losses[-1]
        mean_loss = loss_sum / len(batch_losses)
        loop.set_postfix(loss=mean_loss)


100%|██████████| 152/152 [00:52<00:00,  2.87it/s, loss=-1.18]
100%|██████████| 152/152 [00:52<00:00,  2.89it/s, loss=-3.57]
 66%|██████▌   | 100/152 [00:36<00:18,  2.86it/s, loss=-4.12]

In [None]:
# NOTE(review): `iou` / `targets` here are whatever the last training step left in the
# kernel (already passed through get_indices); this cell depends on that hidden state.
POSITIVE_THRESHOLD = 0.8
N_PER_GROUP = 10


def split_by_score(scores, threshold=POSITIVE_THRESHOLD):
    """Split 1-D ``scores`` into two index tensors.

    Returns ``(positive, partial)``: indices with ``score > threshold`` and indices with
    ``0 < score <= threshold``. The upper bound of ``partial`` is inclusive so a score
    exactly equal to ``threshold`` is no longer dropped from both groups.
    """
    positive = torch.where(scores > threshold)[0]
    partial = torch.where((scores > 0.0) & (scores <= threshold))[0]
    return positive, partial


normal_cell, normal_cell_not = split_by_score(iou[0, :, 0])
abnormal_cell, abnormal_cell_not = split_by_score(targets[0, :, 0])

for group in (normal_cell, normal_cell_not, abnormal_cell, abnormal_cell_not):
    print(len(group))

# Random subset of up to N_PER_GROUP positions within each group.
normal_cell_indices = torch.randperm(len(normal_cell))[:N_PER_GROUP]
normal_cell_not_indices = torch.randperm(len(normal_cell_not))[:N_PER_GROUP]
abnormal_cell_indices = torch.randperm(len(abnormal_cell))[:N_PER_GROUP]
abnormal_cell_not_indices = torch.randperm(len(abnormal_cell_not))[:N_PER_GROUP]


In [None]:
# Map the sampled positions back to grid-cell indices, one group at a time.
ncell, ncell_not, abcell, abcell_not = [
    cells[picks]
    for cells, picks in (
        (normal_cell, normal_cell_indices),
        (normal_cell_not, normal_cell_not_indices),
        (abnormal_cell, abnormal_cell_indices),
        (abnormal_cell_not, abnormal_cell_not_indices),
    )
]


In [None]:
# Sorted concatenation of every sampled cell index.
sampled_cells = torch.cat((ncell, ncell_not, abcell, abcell_not))
indice, _ = torch.sort(sampled_cells, dim=-1)
indice

In [None]:
# Removed stale inspection of `bbox`: no such name exists at notebook scope (it is only
# a parameter of visualize_bbox), so this cell raised NameError on a fresh kernel.
# Inspect `bboxes` after the bbox-extraction cell instead.

In [None]:
# Single-argument torch.where(cond) is nonzero(as_tuple=True); take the index tensor.
abnormal_cell = (targets[0, :, 0] > 0.8).nonzero(as_tuple=True)[0]
abnormal_cell

In [None]:
# Re-draw the random sample of normal-cell positions (replaces any earlier draw).
normal_cell_indices = torch.randperm(normal_cell.shape[0])[:10]
normal_cell_indices

In [None]:
# Random permutation of 0..12 (the scratch list had 13 elements).
torch.randperm(13)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

BOX_COLOR = (255, 0, 0) # Red
TEXT_COLOR = (255, 255, 255) # White

def visualize_bbox(img, bbox, color=BOX_COLOR, thickness=2):
    """Draw one axis-aligned box on ``img`` and return the result of ``cv2.rectangle``.

    Args:
        img: image array handed to ``cv2.rectangle`` (OpenCV draws into it).
        bbox: ``(x_min, y_min, x_max, y_max)`` pixel coordinates; each is truncated with ``int``.
        color: colour tuple forwarded to ``cv2.rectangle`` (defaults to ``BOX_COLOR``).
        thickness: line thickness in pixels.

    Returns:
        The image returned by ``cv2.rectangle``.
    """
    x_min, y_min, x_max, y_max = map(int, bbox)
    # Honour the caller's colour; it used to be ignored in favour of BOX_COLOR.
    return cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

def visualize(image, bboxes):
    """Show ``image`` with every box in ``bboxes`` drawn on a copy of it.

    ``image`` must support ``.copy()`` (e.g. a NumPy array); the input is left untouched.
    Each box is ``(x_min, y_min, x_max, y_max)``, drawn via ``visualize_bbox``.
    Note: this definition shadows the ``visualize`` imported from the project module.
    """
    canvas = image.copy()
    print(canvas.shape)
    for box in bboxes:
        canvas = visualize_bbox(canvas, box)

    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(canvas)

In [None]:
# images, _, _ = next(iter(test_loader))
images, iou, _ , path = next(iter(train_loader))
print(images.shape)
images = images.permute(0,3,1,2).to(device)
# image = image.to(device)
outputs = model(images)

In [None]:
# print(outputs.shape)
# iou[0] > 0.8

In [None]:
# Side length of the square output grid.
# NOTE(review): hard-coded 15 (15*15 cells per image); cell 7 reports 3721 (= 61*61)
# target cells per 2048-px image — confirm this matches the model's output size.
GRID_SIDE = 15
# Reshape once; -1 infers the batch size, so a short final batch also works.
grid = outputs.view(-1, GRID_SIDE, GRID_SIDE, 2)

print((grid[1, :, :, 0] > 0.9).sum())
print((grid[1, :, :, 1] > 0.9).sum())
for sample in range(3):
    print(torch.max(grid[sample, :, :, 1]))


In [None]:
# Turn confident grid cells of one batch sample into pixel boxes for visualize().
# NOTE(review): grid geometry is hard-coded — 15x15 cells, 64-px windows at a 32-px
# stride (covers 32*14 + 64 = 512 px). Cell 7 reports 3721 (= 61*61) cells per
# 2048-px image, so confirm these against the model/dataset before trusting the boxes.
GRID_SIDE = 15
CELL_STRIDE = 32
CELL_SIZE = 64
SCORE_THRESHOLD = 0.9
SAMPLE_IDX = 0

grid = outputs.view(-1, GRID_SIDE, GRID_SIDE, 2)
rows, cols = (grid[SAMPLE_IDX, :, :, 0] >= SCORE_THRESHOLD).cpu().nonzero(as_tuple=True)

# NOTE(review): x comes from the grid row and y from the grid column; this presumably
# relies on the display cell transposing the image with permute(2, 1, 0).
bboxes = []
for r, c in zip(rows.tolist(), cols.tolist()):
    x_min, y_min = r * CELL_STRIDE, c * CELL_STRIDE
    bboxes.append([x_min, y_min, x_min + CELL_SIZE, y_min + CELL_SIZE])
bboxes

In [None]:
# outputs.view(2,15,15,2)[1,:,:,1] > 0.5

In [None]:
# images[0].permute(1,2,0)

In [None]:
# bboxes are built from batch sample 0 (see the bbox cell), so draw them on sample 0,
# not sample 1. permute(2, 1, 0) maps (C, H, W) -> (W, H, C), i.e. a transposed image.
visualize(images[0].permute(2, 1, 0).cpu().detach().numpy(), bboxes)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(images[0].permute(2, 1, 0).cpu().detach().numpy())