In [None]:
import argparse
import time
import os
from pathlib import Path
import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from random import randint
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, \
                check_imshow, non_max_suppression, \
                scale_coords, xyxy2xywh, strip_optimizer, set_logging, \
                increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel
from utils.download_weights import download

from Distortion_detection import Models
import matplotlib.pyplot as plt

# Parameters

In [None]:
# --- Input and detector -------------------------------------------------------
# `source` may be a folder of images/videos, or a single image or video file.
source = 'inference/images/'
weights = 'runs/train/cluster/weights/best.pt'  # detector checkpoint
imgsz = 640  # requested inference size (the model-loading cell re-checks it against the stride)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --- Detection settings -------------------------------------------------------
conf_thres = 0.25  # confidence threshold handed to non_max_suppression
iou_thres = 0.45   # IoU threshold handed to non_max_suppression
augment = False    # forwarded to the detector call as `augment=`

# --- Output -------------------------------------------------------------------
view_img = True    # draw the boxes and display every processed image
save_img = False   # not read by the cells visible in this notebook

# --- Distortion detection -----------------------------------------------------
classify = True    # also run the distortion classifier on every image
# Human-readable names, indexed by the class id predicted by the classifier.
labels = [
    "Without noise",
    "Salt and pepper",
    "Speckle noise",
    "Uneven illumination",
]

In [None]:
# Logging setup (project helper), then device selection: first CUDA device if
# one is available, CPU otherwise.
set_logging()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Half precision is used on every non-CPU device (only supported on CUDA).
half = not (device.type == 'cpu')

# Load the detector checkpoint directly onto the selected device (FP32 model).
model = attempt_load(weights, map_location=device)
# Largest stride reported by the model, as a plain int.
stride = int(model.stride.max())
# Project helper that checks the requested image size against the stride;
# its return value replaces `imgsz` for the rest of the notebook.
imgsz = check_img_size(imgsz, s=stride)

# Convert the weights to FP16 when not running on the CPU.
if half:
    model.half()

In [None]:
# Optional second model: predicts which kind of distortion affects an image.
if classify:
    # output_dim=4: the predicted class index is later used to index `labels`,
    # which has four entries.
    modelc = Models.NoisClassificationModel(
        n_patches=20,
        token_dim=512,
        nbr_blocks=1,
        nbr_heads=8,
        output_dim=4,
        with_outhead=True,
    )
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint that was saved from a GPU still loads on a CPU-only machine.
    state_dict = torch.load(
        './Distortion_detection/weights/best_model.pth',
        map_location=device,
    )
    modelc.load_state_dict(state_dict)
    modelc.to(device)
    modelc.eval()  # inference mode (disables dropout / batch-norm updates)

In [None]:
# Set Dataloader
dataset = LoadImages(source, img_size=imgsz, stride=stride)

# Get names and colors
# When the model is wrapped (exposes `.module`), the class names live on the
# wrapped module; otherwise they are read from the model itself.
names = getattr(model, 'module', model).names
# One random colour triple per class. `random` is numpy's random module here,
# whose randint excludes the upper bound, so each channel is in 0..254.
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

# Run inference once on a dummy batch (non-CPU devices only).
# The output is discarded, so autograd is switched off to avoid building a
# computation graph for nothing.
if device.type != 'cpu':
    with torch.no_grad():
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once

# Shape of the last batch the model was warmed up for; the inference loop
# compares each incoming batch against these values.
old_img_w = old_img_h = imgsz
old_img_b = 1

In [None]:
t0 = time.time()

for path, img, im0s, vid_cap in dataset:
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    # Warmup
    if device.type != 'cpu' and (old_img_b != img.shape[0] or old_img_h != img.shape[2] or old_img_w != img.shape[3]):
        old_img_b = img.shape[0]
        old_img_h = img.shape[2]
        old_img_w = img.shape[3]
        for i in range(3):
            model(img, augment=augment)[0]

    # Inference
    t1 = time_synchronized()
    pred = model(img, augment=augment)[0]
    t2 = time_synchronized()

    # Apply NMS
    pred = non_max_suppression(pred, conf_thres, iou_thres) #, classes=classes, agnostic=agnostic_nms)
    t3 = time_synchronized()

    # Apply Classifier
    if classify:
        noisyimg = im0s.copy()
        noisyimg = cv2.cvtColor(noisyimg,cv2.COLOR_BGR2GRAY) 
        noisyimg = torch.tensor(noisyimg).to(device) / 255.0
        noisyimg = noisyimg.unsqueeze(0)
        noisyimg = noisyimg.unsqueeze(0)
        outputs = modelc(noisyimg)
        _, pred_distorsion = torch.max(outputs, 1)
        pred_distorsion = pred_distorsion.cpu().numpy()[0]
            

    # Process detections
    for i, det in enumerate(pred):  # detections per image
        p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

        
        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

            # Print results
            for c in det[:, -1].unique():
                n = (det[:, -1] == c).sum()  # detections per class
                s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

            # Write results
            for *xyxy, conf, cls in reversed(det):

                # Add bbox to image
                if view_img: 
                    label = f'{names[int(cls)]} {conf:.2f}'
                    plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=1)
                        
        # Print time (inference + NMS)
        print(f'{s}Done. ({(1E3 * (t2 - t1)):.1f}ms) Inference, ({(1E3 * (t3 - t2)):.1f}ms) NMS')

        # Stream results
        if view_img:
            %matplotlib inline
            plt.figure(figsize=(8,8))
            plt.axis('off')
            plt.imshow(im0)
            plt.show()
            print("Type of the distortion: ", labels[pred_distorsion])