In [1]:
import os
import glob
import torchvision
import tifffile as ti
import numpy as np
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import nms
from torch_snippets import *
from PIL import Image
import plotly.express as px
import pandas as pd

In [2]:
model_file = 'diatom-model_02.pt'
MODEL_ROOT = './model/'
TEMP_DIR = './temp/'
VIDEO_DIR = './videoFiles/'
LABEL_DIR = './OutputLabels/'
MAX_DIMENSION = 50
MIN_DIMENSION = 10

In [3]:
## Load the model
num_classes = 2
target2label = {1: 'diatom', 0: 'background'}
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='FasterRCNN_ResNet50_FPN_Weights.DEFAULT')
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.load_state_dict(torch.load(MODEL_ROOT + model_file))
modelDetails = model.eval()

In [4]:
def preprocess_image(img):
    #img = torch.tensor(img, dtype=torch.float32).permute(2,0,1)
    img = torch.tensor(img, dtype=torch.float64).permute(2,0,1)
    return img.to(device).float()

In [5]:
def decode_output(output):
    'convert tensors to numpy arrays'
    bbs = output['boxes'].cpu().detach().numpy().astype(np.uint16)
    labels = np.array([target2label[i] for i in output['labels'].cpu().detach().numpy()])
    confs = output['scores'].cpu().detach().numpy()
    ixs = nms(torch.tensor(bbs.astype(np.float32)), torch.tensor(confs), 0.05)
    bbs, confs, labels = [tensor[ixs] for tensor in [bbs, confs, labels]]

    if len(ixs) == 1:
        bbs, confs, labels = [np.array([tensor]) for tensor in [bbs, confs, labels]]
    return bbs.tolist(), confs.tolist(), labels.tolist()

In [6]:
batch_size = 61
def getImagesFromVideo(vid, name):
    im = ti.imread(vid)
    df = pd.DataFrame(columns=['timestamp', 'x0', 'y0', 'x1', 'y1'])
    count = 0
    for i in im:
        fig = px.imshow(i, binary_string=True)
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        fig.write_image(TEMP_DIR + name + "_"+ str(count) + ".png", height=i.shape[0], width=i.shape[1])
        count = count + 1
        pass
    files = glob.glob(TEMP_DIR + '/*.png')
    img_list = find(name, files)
    veri_img = []
    for i in img_list:
        veri = Image.open(i).convert("RGB")
        veri = np.array(veri.resize((1392,1040), resample=Image.Resampling.BILINEAR))/255.
        veri = preprocess_image(veri)
        veri_img.append(veri)
        os.remove(i)
        pass
    print(f'veri_img: {len(veri_img)}')
    for i in range(int(len(veri_img) / batch_size) + 1):
        start = i * batch_size
        end = (i + 1) * batch_size
        if start >= len(veri_img):
            break
            pass
        if end > len(veri_img):
            end = -1 # len(veri_img)
            pass
        print(f'start: {start} end: {end}', end='\t')
        outputs = model(veri_img[start:end])
        for ix, output in enumerate(outputs):
            bbs, confs, labels = decode_output(output)
            filtered_bbs = []
            for b in bbs:
                if abs(b[2] - b[0]) < MAX_DIMENSION and abs(b[3] - b[1]) < MAX_DIMENSION and (abs(b[2] - b[0]) > MIN_DIMENSION or abs(b[3] - b[1]) > MIN_DIMENSION):
                    temp_df = pd.DataFrame([{'timestamp': ix + start, 'x0': b[0], 'y0': b[1], 'x1': b[2], 'y1': b[3]}])
                    df = pd.concat([df, temp_df], ignore_index=True)
                    filtered_bbs.append(b)
                    pass
                pass
            if (ix + start) % 45 == 0: # 24 == 0:
                show(veri_img[ix + start].cpu().permute(1,2,0), bbs=bbs)#, bbs=filtered_bbs, texts=labels, sz=15)
                show(veri_img[ix + start].cpu().permute(1,2,0), bbs=filtered_bbs)
                pass
            pass
        pass
    df.to_csv(LABEL_DIR + name + '.csv', index=False)
    pass


SyntaxError: expected ':' (200046067.py, line 39)

In [None]:
getImagesFromVideo(VIDEO_DIR + 'agar 1.tif', 'agar 1')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'agar 2.tif', 'agar 2')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'agar sw 1 (unmasked).tif', 'agar sw 1 (unmasked)')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'agar sw 2.tif', 'agar sw 2')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'agar sw surface 1.tif', 'agar sw surface 1')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'agar sw surface 2.tif', 'agar sw surface 2')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'Movie 16 5x obj.tif', 'Movie 16 5x obj')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'Movie 29 5x obj 30 degrees.tif', 'Movie 29 5x obj 30 degrees')

In [None]:
getImagesFromVideo(VIDEO_DIR + 'Movie 7 5x obj cropped.tif', 'Movie 7 5x obj cropped')