### Configuration (run this cell before making predictions)

In [None]:
import argparse

class Args(argparse.Namespace):
    """Default experiment configuration handed to ``main``.

    All settings are class-level defaults; code that needs different values
    (e.g. ``create_timeseries``) assigns them on the ``args`` instance.
    """

    # ---- General settings ------------------------------------------------
    annotation_path = './data/annotations/all_annotations.json'
    images_path = './data/images'
    n_to_visualize = 0
    match_prompt_img = False
    timeseries = None  # set to a timeseries folder by create_timeseries

    # ---- Experiment settings ---------------------------------------------
    model = 'SegGPT'
    location = '1'
    batch_size = 5
    remove_posthoc = True
    prompt_imgs = [4]
    remove_input = False
    n_patches = 1

    # ---- SegGPT ----------------------------------------------------------
    binary_mask_cutoff = 0.8

    # ---- PerSAM ----------------------------------------------------------
    min_feature_sim = 0.15
    persam_training = False

    # ---- YOLO ------------------------------------------------------------
    yolo_test_path = None
    yolo_model_path = './models/pretrained_yolo_models/best.pt'
    conf_threshold = 0.25


# Shared configuration instance used by the cells below.
args = Args()

### Dataset class balance

#### Mask counts

In [None]:
import os
from collections import Counter
from tqdm import tqdm
import pandas as pd

from dataset import get_gt_masks

# Count, per location, how many annotated masks carry each label.
label_counter = {}
for location in tqdm(os.listdir('./data/images')):
    image_folder = f'./data/images/{location}'
    annotations = get_gt_masks('./data/annotations/all_annotations.json', image_folder)

    location_counts = Counter()
    for file in os.listdir(image_folder):
        _, file_labels = annotations[file]
        location_counts.update(file_labels)

    label_counter[location] = location_counts

label_counter

In [None]:
# Get counters in correct format: map raw annotation labels to table column names.
labels = ['floating_trash_in_system', 'floating_trash_outside_system', 'water', 'barrier']
display_names = {
    'floating_trash_in_system': 'In system',
    'floating_trash_outside_system': 'Out system',
    'barrier': 'Barrier',
    'water': 'Water',
}

counter_vis = {}
for location, counter in label_counter.items():
    # Make every label explicit in label_counter (a Counter already reports
    # missing labels as 0; this only affects how label_counter displays).
    for l in labels:
        counter.setdefault(l, 0)
    counter_vis[location] = {name: counter[raw] for raw, name in display_names.items()}

# Collect counts across locations. Keys come from display_names, so no
# particular location (e.g. '1') has to be present.
overall_counter = {name: 0 for name in display_names.values()}
for counts in counter_vis.values():
    for name in overall_counter:
        overall_counter[name] += counts[name]

counter_vis['Overall'] = overall_counter
df = pd.DataFrame.from_dict(counter_vis).T
df['total'] = df.sum(axis=1)
# Express each class as a percentage of all masks in that row.
for c in df.columns.drop('total'):
    df[c] = df[c] * 100 / df['total']
df.round(1)

#### Pixel counts

In [None]:
import os
from collections import Counter
from tqdm import tqdm
import torch
import numpy as np

from dataset import get_gt_masks

full_labels = ['floating_trash_in_system', 'floating_trash_outside_system', 'barrier', 'water']
images_root = './data/images'
annotations_path = './data/annotations/all_annotations.json'

# Sum annotated pixels per label and location. Locations are listed from the
# same root the images are read from, so the two always agree.
pixel_counter = {}
for location in tqdm(os.listdir(images_root)):
    image_folder = f'{images_root}/{location}'
    annotations = get_gt_masks(annotations_path, image_folder)

    pixel_counter[location] = {x: 0 for x in full_labels}

    for file in os.listdir(image_folder):
        masks, labels = annotations[file]
        labels = np.array(labels)

        for l in full_labels:
            # Sums over all masks of this label, so overlapping masks count once per mask.
            class_masks = masks[labels == l]
            pixel_counter[location][l] += torch.sum(class_masks).item()

pixel_counter

In [None]:
# Collect counts across locations. An 'Overall' entry left by a previous run of
# this cell is skipped so re-running does not double-count it.
overall_counter = {x: 0 for x in full_labels}
for location, counts in pixel_counter.items():
    if location == 'Overall':
        continue
    for label in full_labels:
        overall_counter[label] += counts[label]
pixel_counter['Overall'] = overall_counter

df = pd.DataFrame.from_dict(pixel_counter).T
df['total'] = df.sum(axis=1)
# Express each class as a percentage of all annotated pixels in that row.
for c in full_labels:
    df[c] = df[c] * 100 / df['total']
# Rounded to match the mask-count table.
df.rename(columns={'floating_trash_in_system': 'in-system', 'floating_trash_outside_system': 'out-system'}).round(1)

### Bin GT masks

In [None]:
import json
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from dataset import get_gt_masks
from torchvision.transforms.functional import resize

# Per-location pixel thresholds [small/medium, medium/large] used to bin
# per-image in-system GT pixel totals.
location_bins = {
'1' : np.array([0.125, 0.25])*1e6,
'2' : np.array([0.1, 0.2])*1e6,
'3' : np.array([0.1, 0.2])*1e6,
'4' : np.array([0.15, 0.3])*1e6,
'5' : np.array([0.005, 0.02])*1e6,
'6' : np.array([0.15, 0.3])*1e6,
}

# Print each threshold as a percentage of the location's combined GT mask area.
combined_masks_dir = './data/combined_gt_masks'
for gt_file in sorted(os.listdir(combined_masks_dir)):
    location, ext = os.path.splitext(gt_file)
    # Skip stray entries (e.g. checkpoints) and locations that have no bins.
    if ext != '.pt' or location not in location_bins:
        continue
    gt_mask = torch.load(os.path.join(combined_masks_dir, gt_file))
    total_px = torch.sum(gt_mask)

    print(location, [round((x * 100 / total_px).item(), 1) for x in location_bins[location]])

In [None]:
# For each location: total the in-system GT pixels of every image, bin those
# per-image totals with location_bins, and plot a histogram next to the bin counts.
# NOTE(review): this cell reads from 'data/imagese', while other cells in this
# notebook read from './data/images'. Confirm which directory is intended.
for location in os.listdir('data/imagese'):
    location_folder = './data/imagese/' + location
    gt_masks = get_gt_masks('./data/annotations/all_annotations.json', location_folder)

    gt_mask_sizes = []
    for file in os.listdir(location_folder):
        masks, labels = gt_masks[file]

        in_system_ids = np.where(np.array(labels) == 'floating_trash_in_system')
        masks = masks[in_system_ids]

        # Locations 5 and 6 are resized to 1944x2592 before counting pixels.
        if location in ['5', '6']:
            masks = resize(masks, (1944, 2592))

        # Total in-system pixels in this image (overlaps counted once per mask).
        gt_mask_sizes.append(torch.sum(masks).item())

    gt_mask_sizes = np.array(gt_mask_sizes)
    small_max, medium_max = location_bins[location]
    small = np.sum(gt_mask_sizes <= small_max)
    medium = np.sum((gt_mask_sizes > small_max) & (gt_mask_sizes <= medium_max))
    large = np.sum(gt_mask_sizes > medium_max)

    print(small, medium, large)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].hist(gt_mask_sizes)
    axs[1].bar(['small', 'medium', 'large'], [small, medium, large])
    # plt.title would only title the last (bar) subplot; title the whole figure instead.
    fig.suptitle(location)
    plt.show()

### Create post-hoc removal masks

In [None]:
from torchvision.io import read_image
from visualization import draw_ann

# Build a "post-hoc removal" mask per selected location: the union of every GT
# mask annotated there. Preview it on one image, then save it to
# ./data/combined_gt_masks/<location>.pt.
locations_to_process = {'5'}  # only these locations are (re)generated by this cell
combined_masks_dir = './data/combined_gt_masks'
os.makedirs(combined_masks_dir, exist_ok=True)

for location in os.listdir('./data/images'):
    if location not in locations_to_process:
        continue

    image_folder = f'./data/images/{location}'
    annotations = get_gt_masks('./data/annotations/all_annotations.json', image_folder)

    # Keep a running per-pixel sum instead of stacking every image's masks in memory.
    # Every location handled here is resized to 1944x2592 (other cells only resize 5 and 6).
    pixel_sum = None
    for masks, _ in tqdm(annotations.values()):
        masks = resize(masks, (1944, 2592))
        image_sum = torch.sum(masks, dim=0)
        pixel_sum = image_sum if pixel_sum is None else pixel_sum + image_sum

    if pixel_sum is None:
        print(f'No annotations for location {location}; nothing saved.')
        continue
    combined_masks = pixel_sum > 0

    # Preview the combined mask on the first image (by name) of the location.
    first_img = sorted(os.listdir(image_folder))[0]
    image = read_image(f'{image_folder}/{first_img}')
    image = resize(image, (1944, 2592))
    annotated = draw_ann(image, combined_masks.unsqueeze(0), int_colors=[[255, 0, 110]])

    plt.imshow(annotated)
    plt.axis('off')
    plt.show()

    torch.save(combined_masks, f'{combined_masks_dir}/{location}.pt')

### Visualize prompt images and masks

In [None]:
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision.transforms.functional import resize, to_pil_image
import torch

from dataset import get_gt_masks
from visualization import draw_ann
from utils import get_int_from_label
from dataset import finetuning_images

# Draw the selected in-system prompt masks on each fine-tuning image, grouped by location.
mask_colors = [[255, 0, 110], [131, 56, 236]]
imgs = {location: [] for location in finetuning_images}

for location, prompt_masks in finetuning_images.items():
    image_folder = f'./data/images/{location}'
    annotations = get_gt_masks('./data/annotations/all_annotations.json', image_folder)
    needs_resize = location in ['5', '6']

    for img_id, mask_ids in prompt_masks.items():
        original_image = read_image(f'{image_folder}/{img_id}')
        masks, labels = annotations[img_id]

        # Images (and masks) from locations 5 and 6 are resized to 1944x2592.
        if needs_resize:
            original_image = resize(original_image, (1944, 2592))
            masks = resize(masks, (1944, 2592))

        # String labels -> int labels; label 0 is treated as "in system" here.
        int_labels = torch.Tensor([get_int_from_label(l) for l in labels])
        in_system = masks[int_labels == 0] > 0

        imgs[location].append(draw_ann(original_image, in_system[mask_ids], mask_colors))

In [None]:
# Grid of annotated prompt images: one row per location, one column per prompt image.
# The grid size follows `imgs` instead of assuming 6 locations x 5 images.
if imgs:
    n_rows = len(imgs)
    n_cols = max((len(v) for v in imgs.values()), default=0) or 1
    # squeeze=False keeps `axs` 2-D even when there is a single row or column.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 20 / 6 * n_rows), squeeze=False)
    for row, (loc, loc_imgs) in zip(axs, imgs.items()):
        for a in row:
            a.axis('off')
        for a, img in zip(row, loc_imgs):
            a.imshow(img)
        # Title sits over the middle column of each row.
        row[n_cols // 2].set_title(loc.replace('_', ' '))
    plt.show()

### Timeseries predictions

In [None]:
import torch
import os
from datetime import datetime
import pandas as pd
from visualization import draw_ann
from tqdm import tqdm
from main import main
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import ImageDraw

In [None]:
def get_dt(string):
    """Parse the capture time embedded in a timeseries image name.

    The name is split on '_': the second field is the date as YYYYMMDD and the
    third field begins with the time as HHMM (anything after those four
    characters, e.g. a file extension, is ignored). Returns a naive datetime.
    """
    fields = string.split('_')
    day_part, clock_part = fields[1], fields[2]
    pieces = (day_part[:4], day_part[4:6], day_part[6:], clock_part[:2], clock_part[2:4])
    return datetime.strptime(' '.join(pieces), '%Y %m %d %H %M')

def create_timeseries(folder, model, gif=False, graph=False):
    """Run `model` on one timeseries folder and optionally export a GIF and/or graph.

    Mutates the global ``args`` (timeseries, location, model and the
    model-specific settings below) and then calls ``main(args)``.

    Args:
        folder: timeseries folder; its last '/'-separated component must start
            with ``<location>_`` because the location id is parsed from it.
        model: 'SegGPT' or 'YOLO'.
        gif: if True, write ``gifs/<folder name>_<model>.gif`` with the
            predicted masks and timestamp drawn on every frame.
        graph: if True, plot the summed predicted-mask pixels of each frame over time.
    """
    # Validate before touching the shared args.
    assert model in ['SegGPT', 'YOLO']

    args.timeseries = folder
    args.location = args.timeseries.split('/')[-1].split('_')[0]
    args.model = model

    if model == 'SegGPT':
        args.binary_mask_cutoff = 0.8 if args.location == '1' else 0.95
        args.prompt_imgs = [4] if args.location == '1' else [3]
        args.yolo_model_path = None

    if model == 'YOLO':
        args.binary_mask_cutoff = 0.8
        args.prompt_imgs = []
        args.yolo_model_path = './models/pretrained_yolo_models/best.pt'

    # Make predictions
    output = main(args)

    # Name outputs after this call's folder, not a global loop variable.
    ts_name = folder.split('/')[-1]

    if gif:
        _save_timeseries_gif(output, f'gifs/{ts_name}_{model}.gif')

    if graph:
        _plot_timeseries_graph(output, f'{ts_name.replace("_", " ")} - {args.model}')


def _save_timeseries_gif(output, gif_path):
    """Write an animated GIF with the union of predicted masks and the timestamp on each frame."""
    frames = []
    for o in tqdm(output):
        img, mask, img_id = o['images'], o['predicted_masks'], o['img_id']

        # Union of all predicted masks as a single (1, H, W) boolean mask.
        mask = torch.sum(mask, dim=0).unsqueeze(0) > 0
        frame = draw_ann(img, mask, int_colors=[[255, 0, 110]])

        # ImageDraw draws in place on `frame`.
        dt_string = datetime.strftime(get_dt(img_id), '%Y-%m-%d %H:%M')
        ImageDraw.Draw(frame).text((0, 0), dt_string, fill=(0, 0, 0), font_size=100)
        frames.append(frame)

    if not frames:
        return
    os.makedirs(os.path.dirname(gif_path) or '.', exist_ok=True)
    # append_images holds only the frames AFTER the first, which save() writes itself.
    frames[0].save(gif_path, format='GIF', append_images=frames[1:], save_all=True, duration=300, loop=0)


def _plot_timeseries_graph(output, title):
    """Plot each frame's summed predicted-mask pixels against its capture time."""
    # Take each time from the prediction's own img_id so times and values stay
    # aligned regardless of directory listing order.
    df = pd.DataFrame({
        'Time': [get_dt(o['img_id']) for o in output],
        'Predicted Pixels': [torch.sum(o['predicted_masks']).item() for o in output],
    }).sort_values('Time')
    df.plot(x='Time')
    plt.title(title)
    plt.show()

# Run every folder in ./data/timeseries through both models and plot the
# predicted pixel counts over time (no GIFs).
timeseries = os.listdir('./data/timeseries')
models_to_compare = ['YOLO', 'SegGPT']
for ts in timeseries:
    for model in models_to_compare:
        create_timeseries(ts, model, gif=False, graph=True)
