In [11]:
import numpy as np
import torchvision.models as models
import sys
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from PIL import Image
import gdown


In [12]:
# Make Google Drive (where the test images live) visible under /content/drive.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [13]:
# Input/output locations on the Colab filesystem.
TestSetDir = "/content/drive/MyDrive/Test"        # folder of numerically named test images
CompetitionModelPath = "/content/Team21.pth"      # where the downloaded weights are stored
ResultsPath = "/content/Team21.txt"               # one predicted label per line

In [23]:
# Download the trained competition weights from Google Drive to CompetitionModelPath.
url = 'https://drive.google.com/uc?id=1Ey25p4MW9FZjEQ3-QAWEc-D8CY89prUm'
downloaded_path = gdown.download(url, CompetitionModelPath, quiet=False)
# Some gdown versions report failures (e.g. access denied) by returning None
# instead of raising; stop here rather than failing later inside torch.load.
if downloaded_path is None:
    raise RuntimeError(f"Failed to download model weights from {url}")
downloaded_path

Downloading...
From: https://drive.google.com/uc?id=1Ey25p4MW9FZjEQ3-QAWEc-D8CY89prUm
To: /content/Team21.pth
100%|██████████| 94.4M/94.4M [00:00<00:00, 210MB/s]


'/content/Team21.pth'

## Preprocessing

Each test image is resized to 256×256 and haze-corrected by per-band dark-object subtraction.

In [24]:
def haze_Removal(img, threshold=100):
    """Remove haze from each band of an image by dark-object subtraction.

    For every channel, the haze offset is the lowest grey level whose
    histogram count is strictly greater than ``threshold``. That offset is
    subtracted from the channel with saturation at 0, so values never wrap
    around and pixels at or below the offset become 0.

    Args:
        img: (H, W, C) uint8 array, e.g. as returned by ``cv2.imread``.
            It is not modified.
        threshold: minimum pixel count (exclusive) for a grey level to be
            taken as the haze floor. One value is shared by all bands because
            their histograms are alike; 100 was chosen empirically for the
            256x256 test images.

    Returns:
        A new array with the same shape and dtype as ``img``. A band in which
        no grey level has more than ``threshold`` pixels is returned unshifted.

    NOTE(review): earlier versions left pixels at or below the offset
    unchanged instead of setting them to 0. The model must see the same
    preprocessing at train and test time; confirm the training notebook
    uses this version.
    """
    new_img = img.copy()

    for band_num in range(img.shape[2]):
        band = new_img[:, :, band_num]  # a view: writes below go into new_img

        # Pixel count per grey level 0..255; for uint8 input this equals
        # cv2.calcHist([band], [0], None, [256], [0, 256]).
        hist = np.bincount(band.ravel(), minlength=256)
        candidates = np.flatnonzero(hist > threshold)
        if candidates.size == 0:
            # Too few pixels per grey level (e.g. a tiny image): there is no
            # reliable haze floor, so leave this band as is.
            continue
        shift = int(candidates[0])

        # Saturating subtraction: max(v, shift) - shift == max(v - shift, 0)
        # without any uint8 wrap-around.
        band[...] = np.maximum(band, shift) - shift

    return new_img

In [25]:
# Load every test image in numeric filename order, so that line i of the
# results file corresponds to the i-th image.
new_size = (256, 256)  # (width, height) for cv2.resize

files = os.listdir(TestSetDir)
# Only names whose part before the first "." is a decimal number are test
# images. Anything else (e.g. hidden/system files) would crash int() in the
# sort key, so report and skip it instead.
image_files = [name for name in files if name.split(".")[0].isdecimal()]
skipped_files = sorted(set(files) - set(image_files))
if skipped_files:
    print(f"Skipping files without a numeric name in {TestSetDir}: {skipped_files}")
sorted_files = sorted(image_files, key=lambda x: int(x.split(".")[0]))

testing_images = []
for filename in sorted_files:
    path = os.path.join(TestSetDir, filename)
    # NOTE(review): cv2.imread loads colour images in BGR order, and the
    # inference cell hands these arrays to PIL, which treats them as RGB.
    # Confirm the training pipeline loaded images the same way.
    img = cv2.imread(path)
    if img is None:  # cv2.imread returns None instead of raising on unreadable files
        raise ValueError(f"Could not read test image: {path}")

    # Resize to the model's input size.
    img = cv2.resize(img, new_size)

    # Apply haze removal by dark-object subtraction.
    img = haze_Removal(img)

    testing_images.append(img)

## Model

Rebuild the ResNet-50 architecture with a 2-class output layer.

In [26]:
# Rebuild the competition architecture: ResNet-50 with its final layer
# replaced by a 2-class linear head. Every parameter and buffer is overwritten
# by the strict load_state_dict in the testing cell, so no ImageNet weights
# are downloaded here. (models.resnet50() builds a randomly initialised
# network on both old and new torchvision, without the deprecated
# `pretrained` argument.)
resnet50 = models.resnet50()

# Replace the last layer for 2 classes only
features = resnet50.fc.in_features
resnet50.fc = nn.Linear(features, 2)

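## Testing

Load the competition weights, predict a label for each preprocessed test image, and write the results file.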

In [27]:
# Run the classifier on every preprocessed test image (on the CPU).
model = resnet50
# map_location="cpu": the model never moves to a GPU here, so weights saved
# from a GPU session must be remapped or torch.load fails on a CPU runtime.
# NOTE(review): torch.load unpickles the downloaded file; consider
# weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(CompetitionModelPath, map_location="cpu"))
model.eval()

# Per-channel normalization statistics, presumably computed on the training
# set (confirm), in the channel order the model was trained with.
means = [0.3337701, 0.35129565, 0.36801142]
stds = [0.16881385, 0.1562263, 0.16852096]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(tuple(means), tuple(stds)),
])

predictions = []
with torch.no_grad():
    for image in testing_images:
        # NOTE(review): these arrays come from cv2.imread (BGR), but PIL
        # interprets them as RGB; confirm training did the same.
        image_tensor = transform(Image.fromarray(image)).unsqueeze(0)
        output = model(image_tensor)
        # Softmax is monotonic, so the argmax of the logits is the same label
        # the argmax of the probabilities would give.
        predictions.append(output.argmax(dim=1).item())

In [28]:
# Convert model class indices to submission labels: index 1 -> 0, anything
# else -> 1.
# NOTE(review): presumably the training class order is the reverse of the
# competition's label convention; confirm against the training notebook.
finalPreds = [0 if pred == 1 else 1 for pred in predictions]

In [29]:
# Write one label per line, in the same order as the sorted test images.
with open(ResultsPath, 'w') as results_file:
    results_file.writelines(str(label) + '\n' for label in finalPreds)