In [1]:
import torch
import argparse
# import clip
import open_clip
import json
from PIL import Image
import torchvision
import torchvision.transforms as tvt
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
import os
import glob
import pandas as pd
import numpy as np
from tqdm import tqdm
import logging
from contextlib import suppress
import torch.nn.functional as F
from sklearn.metrics import classification_report, balanced_accuracy_score

## DataLoader

In [None]:
class MyDataset(Dataset):
    """Image-folder dataset: one sub-directory of ``root`` per class.

    Class indices follow the *sorted* sub-directory names, so the mapping is
    the same on every OS/filesystem (``os.listdir`` order is arbitrary). The
    evaluation code maps prompt column ``p`` to class ``p // 3``, so this class
    order must match the order of the prompt dictionary.

    Args:
        root: directory containing one folder per class.
        transform: callable applied to each loaded PIL image.
        json_path: if truthy, the ``{"<index>": "<class name>"}`` mapping is
            written to this path as JSON.
    """

    def __init__(self, root, transform, json_path):
        self.root = root
        # Only directories are classes; stray files (Thumbs.db, .DS_Store, ...)
        # would otherwise become empty classes and shift every label index.
        self.class_list = sorted(
            entry for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        )
        class_indices = {str(i): cla for i, cla in enumerate(self.class_list)}
        if json_path:
            with open(json_path, "w") as json_file:
                json.dump(class_indices, json_file)
        self.transform = transform

        self.path_list = []  # one directory path (with trailing separator) per class
        self.img_list = []   # [label, image_path] pairs

        for file_label, cla in enumerate(self.class_list):
            file_path = os.path.join(self.root, cla, "")
            self.path_list.append(file_path)
            # Escape so folder names containing glob metacharacters ([, ], ?, *)
            # still match; sort because glob order is filesystem-dependent too.
            pattern = glob.escape(file_path) + '*'
            for img in sorted(glob.glob(pattern)):
                self.img_list.append([file_label, img])

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img_label, img_path = self.img_list[index]
        # Copy the decoded pixels out and close the file right away; a bare
        # Image.open() keeps the handle open until garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.copy()
        return self.transform(img), img_label

## Confusion Matrix

In [None]:
class ConfusionMatrix(object):
    """Accumulates a class-level confusion matrix from prompt-level predictions.

    ``matrix[pred_class, true_class]`` counts samples: rows are predictions and
    columns are ground truth, matching the axis labels drawn by ``plot``. The
    classifier may score several prompts per class; prompt column ``p`` is
    mapped to class ``p // prompts_per_class``. This assumes the prompts of each
    class are contiguous and listed in class-index order.

    Args:
        num_classes: number of ground-truth classes.
        labels: class names used as tick labels, in class-index order.
        prompts_per_class: prompts scored per class (default 3, the value the
            original implementation hard-coded).
    """

    def __init__(self, num_classes: int, labels: list, prompts_per_class: int = 3):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels
        self.prompts_per_class = prompts_per_class

    def update(self, preds, labels):
        """Add one batch of (predicted prompt column, true class) pairs.

        Both arguments are flattened first, so a ``(1, batch)`` top-1 index
        array pairs element-wise with a ``(batch,)`` label array. Without
        flattening, zipping them paired every prediction with the first label
        only.

        Raises:
            ValueError: if preds and labels hold different numbers of elements.
        """
        preds = np.asarray(preds).reshape(-1)
        labels = np.asarray(labels).reshape(-1)
        if preds.shape != labels.shape:
            raise ValueError(
                f"preds ({preds.size}) and labels ({labels.size}) must have "
                "the same number of elements"
            )
        # np.add.at accumulates repeated (row, col) pairs; `matrix[r, c] += 1`
        # with fancy indexing would count duplicates only once.
        np.add.at(self.matrix, (preds // self.prompts_per_class, labels), 1)

    def summary(self):
        """Print overall accuracy (trace / total count) and return it as ``str``.

        Returns ``"nan"`` when nothing has been accumulated yet.
        """
        total = np.sum(self.matrix)
        acc = np.trace(self.matrix) / total if total else float("nan")
        print("the model accuracy is ", acc)
        return str(acc)

    def plot(self, save_path="patternnet_confusion_matrix.jpg"):
        """Print the matrix, draw it as an annotated heat map, save and show it.

        Args:
            save_path: output image path (the default keeps the original name).
        """
        matrix = self.matrix
        print(matrix)
        plt.figure(figsize=(18, 14))
        plt.imshow(matrix, cmap=plt.cm.Blues)

        plt.xticks(range(self.num_classes), self.labels, rotation=90)
        plt.yticks(range(self.num_classes), self.labels)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix (acc=' + self.summary() + ')')

        # Annotate every cell; use white text on dark (high-count) cells.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.show()

## Zero-shot Classification

In [None]:
# Code adapted from https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/metrics/zeroshot_classification.py

def zero_shot_classifier(model, tokenizer, classnames, templates, prompts, device, amp=True, cupl=False):
    """Build a zero-shot text classifier with one unit-norm column per name.

    Names are taken from ``prompts.keys()`` (one column per prompt, so several
    columns may belong to the same class). When ``prompts`` is ``None`` or
    empty, the function falls back to ``classnames``, as the upstream
    CLIP_benchmark implementation does.

    Args:
        model: object exposing ``encode_text(token_tensor)``.
        tokenizer: callable turning a list of strings into a token tensor.
        classnames: class names; used only when ``prompts`` is empty/``None``.
        templates: format strings containing ``{c}``. When ``cupl`` is true,
            this is instead a mapping from name to a list of ready-made texts.
        prompts: mapping whose keys are the names to embed (values unused here).
        device: device the tokens and the result are moved to.
        amp: run under ``torch.cuda.amp.autocast`` when true.
        cupl: see ``templates``.

    Returns:
        Tensor with one column per name, each column L2-normalized. The shape
        is ``(embed_dim, num_names)``, assuming ``encode_text`` returns
        ``(num_texts, embed_dim)`` — TODO confirm for other model types.
    """
    names = list(prompts.keys()) if prompts else list(classnames)
    autocast = torch.cuda.amp.autocast if amp else suppress
    with torch.no_grad(), autocast():
        zeroshot_weights = []
        for classname in tqdm(names):
            if cupl:
                texts = templates[classname]
            else:
                texts = [template.format(c=classname) for template in templates]
            texts = tokenizer(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)
            # Average the normalized per-template embeddings, then re-normalize.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy for each k in ``topk``.

    Args:
        output: ``(N, C)`` score tensor.
        target: ``(N,)`` tensor of true column indices.
        topk: iterable of k values, each at most C.

    Returns:
        One float per k: the fraction of the N rows whose target is among the
        k highest scores. Every entry is ``nan`` when N == 0.
    """
    n = len(target)
    if n == 0:
        return [float("nan") for _ in topk]
    pred = output.topk(max(topk), 1, True, True)[1].t()  # (max_k, N)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() avoids NumPy's deprecated float() conversion of a 1-element array.
    return [correct[:k].reshape(-1).float().sum().item() / n for k in topk]

    

def run(model, classifier, dataloader, class_num, class_indices, device, amp = True):
    """Score every batch against ``classifier`` and plot a confusion matrix.

    Args:
        model: object exposing ``encode_image``.
        classifier: text-embedding matrix from ``zero_shot_classifier`` (one
            column per prompt).
        dataloader: yields ``(images, target)`` batches.
        class_num: number of ground-truth classes (confusion-matrix size).
        class_indices: ``{"<index>": "<class name>"}``; values label the plot.
        device: device images and targets are moved to.
        amp: run the image forward pass under ``torch.cuda.amp.autocast``.

    Returns:
        ``(pred, true)``: the concatenated float32 CPU logits (one column per
        classifier column) and the concatenated CPU targets.
    """
    autocast = torch.cuda.amp.autocast if amp else suppress
    labels = list(class_indices.values())
    confusion = ConfusionMatrix(num_classes=class_num, labels=labels)
    pred = []
    groundTruth = []

    with torch.no_grad():
        for images, target in tqdm(dataloader):
            images = images.to(device)
            target = target.to(device)

            with autocast():
                image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                logits = (100. * image_features @ classifier)

            groundTruth.append(target.cpu())
            pred.append(logits.float().cpu())
            # Flatten top-1 indices to shape (batch,) so that, when the
            # confusion matrix zips predictions with labels, each prediction
            # pairs with its own label. The old (1, batch) array paired the
            # whole batch with the first label whenever batch_size > 1.
            top1 = logits.topk(1, 1, True, True)[1].reshape(-1)
            confusion.update(top1.cpu().numpy(), target.cpu().numpy())

    pred = torch.cat(pred)
    true = torch.cat(groundTruth)
    confusion.plot()
    return pred, true


def average_precision_per_class(scores, targets):
    """Average precision of every column of ``scores`` against ``targets``.

    For each class, samples are ranked by descending score. The AP is the mean
    of precision@rank over the ranks that hold a positive sample. A class with
    no positives gets 0 (its denominator is clamped to 1).

    Args:
        scores: ``(N, C)`` score tensor.
        targets: ``(N, C)`` 0/1 tensor of ground-truth memberships.

    Returns:
        Float tensor of length C with one AP per class.
    """
    num_classes = scores.size(1)
    ranks = torch.arange(1, scores.size(0) + 1).float()
    per_class_ap = torch.zeros(num_classes)
    for cls in range(num_classes):
        _, order = torch.sort(scores[:, cls], 0, True)
        hits = targets[:, cls][order]
        precision_at_rank = hits.float().cumsum(0).div(ranks)
        num_positive = max(float(hits.sum()), 1)
        per_class_ap[cls] = precision_at_rank[hits.bool()].sum() / num_positive
    return per_class_ap

def evaluate(model, dataloader, tokenizer, classnames, templates, prompts, class_indices, device, amp=True, verbose=False, cupl=False, save_clf=None, load_clfs=None):
    """Zero-shot evaluation of ``model`` on ``dataloader``.

    When the classifier has several prompt columns per class (columns ==
    k * len(classnames)), each class is scored by its best prompt before the
    metrics are computed. The metrics then compare class indices with class
    labels, consistent with the confusion matrix's ``p // k`` mapping. This
    assumes each class's prompts are contiguous and in class-index order.

    Args:
        model: CLIP-style model.
        dataloader: yields ``(images, target)`` batches.
        tokenizer: tokenizer matching ``model``.
        classnames: ground-truth class names, in class-index order.
        templates: forwarded to ``zero_shot_classifier``.
        prompts: forwarded to ``zero_shot_classifier``.
        class_indices: ``{"<index>": "<class name>"}`` used to label the plot.
        device: torch device string.
        amp: use autocast while encoding images.
        verbose: print per-class details.
        cupl: forwarded to ``zero_shot_classifier``.
        save_clf: if not None, path the classifier tensor is saved to.
        load_clfs: paths of saved classifiers; when non-empty, their average is
            used instead of building a classifier from text prompts.

    Returns:
        ``{"mean_average_precision": ...}`` for 2-D (multi-label) targets.
        Otherwise ``{"acc1", "acc3", "mean_per_class_recall"}``, where
        ``acc3`` is ``nan`` with fewer than 3 classes.
    """
    load_clfs = load_clfs or []
    if load_clfs:
        n = len(load_clfs)
        classifier = torch.load(load_clfs[0], map_location='cpu') / n
        for clf_path in load_clfs[1:]:
            classifier = classifier + torch.load(clf_path, map_location='cpu') / n
        classifier = classifier.to(device)
    else:
        classifier = zero_shot_classifier(model, tokenizer, classnames, templates, prompts, device, cupl=cupl)

    if save_clf is not None:
        torch.save(classifier, save_clf)
        # exit() - not sure if we want to exit here or not.
    class_num = len(classnames)
    logits, target = run(model, classifier, dataloader, class_num, class_indices, device, amp=amp)

    num_columns = logits.shape[1]
    if class_num > 0 and num_columns != class_num and num_columns % class_num == 0:
        # Prompt-level -> class-level scores: max over each class's contiguous
        # block of prompts, so argmax equals (prompt argmax) // prompts_per_class.
        prompts_per_class = num_columns // class_num
        logits = logits.reshape(logits.shape[0], class_num, prompts_per_class).amax(dim=2)

    is_multilabel = (len(target.shape) == 2)

    if is_multilabel:
        if verbose:
            print("Detected a multi-label classification dataset")
        # Multiple labels per image, multiple classes on the dataset
        ap_per_class = average_precision_per_class(logits, target)
        if verbose:
            for class_name, ap in zip(classnames, ap_per_class.tolist()):
                print(f"Class: {class_name}, AveragePrecision: {ap}")
        return {"mean_average_precision": ap_per_class.mean().item()}

    # Single label per image: accuracy and mean per-class recall.
    pred = logits.argmax(axis=1)
    if len(classnames) >= 3:
        acc1, acc3 = accuracy(logits, target, topk=(1, 3))
    else:
        acc1, = accuracy(logits, target, topk=(1,))
        acc3 = float("nan")
    mean_per_class_recall = balanced_accuracy_score(target, pred)
    if verbose:
        print(classification_report(target, pred, digits=3))
    return {"acc1": acc1, "acc3": acc3, "mean_per_class_recall": mean_per_class_recall}
    
    

## Model

In [None]:
# Model identifiers live in one place so the tokenizer always matches the model.
MODEL_NAME = 'ViT-B-32'
PRETRAINED = 'openai'  # alternative checkpoint: 'laion2b_s34b_b79k'

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
model.to(device)
# The open_clip README notes models are returned in train mode; switch to eval
# so stochastic layers (dropout / stochastic depth / BatchNorm, where present)
# behave deterministically during evaluation.
model.eval()
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

## AID

### Prompt-to-class map (3 descriptive prompts per AID class)

In [None]:
# Prompt text -> parent class name: 30 AID classes x 3 descriptive prompts = 90 keys.
# Only the KEYS are consumed: zero_shot_classifier embeds one classifier column per
# key (in insertion order), and ConfusionMatrix.update maps column p to class p // 3.
# The three prompts of each class must therefore stay contiguous, and the classes
# must appear in the same order as the dataset's class indices (class-folder order).
# The values are not read by any code in this notebook.
# NOTE(review): some keys/values contain typos ("RESDENTIAL", "MEADIUM", "coould",
# "surronded", "colse", "withone"). They are deliberately left as-is: the key text is
# the prompt fed to the text encoder, so editing it changes the reported results.
map_to_parent = {
    "AIRPORT: A large, open field with a runway in the middle. The field is covered with grass, and there are airplanes parked on the runway. There is a terminal in the background." : "airport",
    "AIRPORT:The land cover in this picture is an airport, specifically an airport runway and surrounding area. The scene provides a comprehensive view of the airport's layout and infrastructure, including the runway, taxiways, and terminal building.": "airport",
    "AIRPORT:the image shows a large grassy field with a runway in the middle, surrounded by a variety of airplanes parked on the runway. The presence of multiple airplanes parked on the runway indicates that the airport is active and in use.": "airport",
    "BARELAND: A dry and barren landscape with little vegetation. It appears to be a rough terrain." : "bareland",
    "BARELAND: The image shows a large, empty lot with dirt and debris." : "bareland",
    "BARELAND: A construction site, which is an area where buildings or structures are being built or renovated. The image shows a large, open area with dirt and sand." : "bareland",
    "BASEBALL FIELD: The image shows a large, grassy field with a baseball diamond in the center, surrounded by a dirt infield and a grass outfield." : "baseball field",
    "BASEBALL FIELD: The image shows a baseball field with a large circular infield, surrounded by a grassy outfield." : "baseball field",
    "BASEBALL FIELD: The image shows a large grassy field with several baseball diamonds, each with a home plate, bases, and a pitcher's mound." : "baseball field",
    "BEACH: The beach is surrounded by water, and there is a sandy area with a large sand dune on the coast." : "beach",
    "BEACH: The land cover in the image is a beach with sand and water. The beach is covered in sand, and there is a body of water, likely an ocean." : "beach",
    "BEACH: The land cover in this picture is a sandy beach with a sandy dune area." : "beach",
    "BRIDGE: The land cover type in this picture is a bridge, that looks like a curved gray-line in the image." : "bridge",
    "BRIDGE: The land cover is a bridge crossing over water. There are also people on the bridge." : "bridge",
    "BRIDGE: The bridge is connecting two lands crossing the water." : "bridge",
    "CENTER: There is a large, modern building, which appears to be a white, and looks like a center." : "center",
    "CENTER: There is a large, circular building surrounded by grass and trees. The building appears to be a large, modern structure that coould be a center." : "center",
    "CENTER: A white irregular-shaped building is in the middle of the image. It is situated in an urban area." : "center",
    "CHURCH: There is a white and gray Christian churches built in a cross-shaped architectural style.": "church",
    "CHURCH: A irregular-shaped white building is situated in the middle of the picture. " : "church",
    "CHURCH: There is a cruciform-shaped church that has a orange roof in the middle of the image surronded by other orange buildings." : "church",
    "COMMERCIAL: The image shows a bird's-eye view of a city, with a large number of buildings and roads, as well as a few parking lots." : "commercial",
    "COMMERCIAL: There are many cars on the roads and a significant number of buildings visible. The cityscape is likely filled with various types of buildings, including commercial, and office spaces, as well as public spaces like parks and plazas." : "commercial",
    "COMMERCIAL: The content of the picture includes various buildings, such as hotels, office buildings, and commercial structures, as well as roads and possibly a body of water, which could be a lake or a river." : "commercial",
    "DENSE RESDENTIAL: The view is characterized by a dense network of streets, buildings, and various structures. There are numerous houses and apartments and a significant amount of urban development." : "dense resdential",
    "DENSE RESDENTIAL: It is a dense neighborhood with a large number of houses. there are also cars parked on the street. The houses are of various colors, including brown, white, and gray, and they are arranged in a grid-like pattern." : "dense resdential",
    "DENSE RESDENTIAL: The area is filled with multiple buildings, including apartments and houses. There are also cars and a pedestrian crossing in the scene. The content of the picture is focused on the urban environment, highlighting the density and diversity of the cityscape." : "dense resdential",
    "DESERT: The image features a sandy beach and the sand dunes are visible in the distance." :"desert",
    "DESERT: The view is a sandy desert, with a sandy surface and some rocky formations. The image shows a close-up view of the sandy desert, which is characterized by its sandy texture and rocky features." :"desert",
    "DESERT: The image shows a sandy desert, with a sandy texture and a sandy color." :"desert",
    "FARMLAND: The is a large, flat, and open field. The field is covered with crops and has a few trees scattered throughout." : "farmland",
    "FARMLAND: There is a large field with a river running through it. The field is covered with crops and green vegetations." : "farmland",
    "FARMLAND: There is a large, flat area of land used for agricultural purposes. The field is covered with green vegetation, which could be crops, or other types of plants." : "farmland",
    "FOREST: There is a dense forest, with a mix of trees and greenery. The image shows a close-up view of the forest, with a lot of trees and foliage, creating a lush and vibrant landscape." : "forest",
    "FOREST: The image shows a large area of forest, with a lot of trees and foliage, creating a lush and vibrant landscape." : "forest",
    "FOREST: The image shows a dense forest, consisting of a large area of trees. The forest is covered in green foliage, which indicates that it is likely a lush and healthy ecosystem." : "forest",
    "INDUSTRIAL: There is a large industrial area, with numerous buildings and warehouses. The image shows a city with a significant amount of commercial and industrial activity, as evidenced by the numerous shipping containers and large buildings." : "industrial",
    "INDUSTRIAL: The image shows a cityscape with a variety of buildings, including office buildings, factories, and warehouses." : "industrial",
    "INDUSTRIAL: It is an industrial area, with multiple factories and buildings visible. The content of the image includes various types of factories, warehouses, and other industrial structures." : "industrial",
    "MEADOW: It is a field of green grass. The image shows a large expanse of green grass. The grass is well-maintained and appears to be healthy." : "meadow",
    "MEADOW: The view is a large, open field of grass. The meadow is covered with dirt and has some dirt roads running through it." : "meadow",
    "MEADOW: The view in the image is a large, open field with grass." : "meadow",
    "MEADIUM RESDENTIAL: The image shows a close-up view of a street with some houses, some of which are red and some are brown. The houses are situated in a row, and there are trees in the background." : "medium resdential",
    "MEADIUM RESDENTIAL: The image shows a colse view of the residential area, which is a neighborhood with not much people. The houses are arranged in a grid-like pattern, with a few roads and sidewalks connecting them." : "medium resdential",
    "MEADIUM RESDENTIAL: The image shows a close view of neighborhood with few houses and buildings, as well as a park or green area. The overall content of the picture is a medium residential neighborhood." : "medium resdential",
    "MOUNTAIN: The image shows a mountainous terrain with a mix of green and brown colors. The mountainous terrain is covered with a variety of vegetation, including grass and trees, which gives it a lush and vibrant appearance." : "mountain",
    "MOUNTAIN: The view in this picture is a mountainous terrain with snow-covered peaks and valleys. The image features a large mountain range, which is covered in snow, and the landscape appears to be quite rugged and rocky." : "mountain",
    "MOUNTAIN: The view in the image is a mountainous terrain, with a large expanse of greenery covering the hills and valleys. The landscape features a mix of grassy hills and rocky outcrops, creating a diverse and visually appealing scene." : "mountain",
    "PARK: It is a park or a recreational area, featuring a large lake, a road, and a pathway. The park is designed with a complex network of roads and paths, which are likely intended for various recreational activities such as walking, jogging, or cycling." : "park",
    "PARK: The view in the image is a large park or amusement park, featuring a variety of water-based attractions, such as water slides, a water park, and a swimming pool. The park is surrounded by a city, and there are also some trees in the area." : "park",
    "PARK: The picture shows a park, with a playground. The image shows a bird's-eye view of the park, which is surrounded by buildings and a city. The park is a green space with a playground, and possibly other recreational facilities." : "park",
    "PARKING: The image shows a large parking lot with many cars parked in rows, creating a grid-like pattern. The parking lot is surrounded by buildings." : "parking",
    "PARKING: The content of the image shows a large parking lot filled with cars, with some empty spaces available. The parking lot is surrounded by buildings and other structures." : "parking",
    "PARKING: The scene is captured from an aerial view, giving a bird's-eye perspective of the parking lot and the vehicles parked within it." : "parking",
    "PLAYGROUND: The image shows a large, green soccer field with a red track surrounding it, which is likely a running track." : "playground",
    "PLAYGROUND: The image shows a large soccer field with trees in the background, and there are people playing soccer on the field." : "playground",
    "PLAYGROUND: The image shows a large, flat, and open area designed for playing soccer. The field is covered in grass, and it is surrounded by a fence, which is likely to be used for marking the boundaries of the field." : "playground",
    "POND: The image shows a large, open area with a pond in the center, surrounded by dirt and rocks." : "pond",
    "POND: The image depicts a serene and picturesque scene with a lake in the center, surrounded by dirt and rocks. The lake is situated in the middle of an open area. The area might be a park or a recreational area." : "pond",
    "POND: There is a large body of water, which appears to be a lake. The image shows a lake with a large hole in the middle, surrounded by grassy land." : "pond",
    "PORT: The image shows a harbor or docking area for boats and ships. There is a large body of water with boats docked in the marina, and there are also buildings and structures in the area." : "port",
    "PORT: The view in this picture is a body of water, specifically a large body of water with boats docked along the shore. The marina is surrounded by a city." : "port",
    "PORT: The content of the image shows a large marina with many boats docked in it. The boats are parked in a circular pattern." : "port",
    "RAILWAY STATION: The image shows a large, open area with a railway station in the center. The railway station is surrounded by a mix of residential and industrial buildings." : "railway station",
    "RAILWAY STATION: The image shows a large area with multiple train tracks, which are surrounded by buildings and other structures." : "railway station",
    "RAILWAY STATION: The image shows a map of the rail yard, with multiple train tracks and a variety of buildings, including a large warehouse." : "railway station",
    "RESORT: The image shows a view of the city from above, with a building that has a pool on its rooftop. The pool is surrounded by a deck and appears to be a popular spot for relaxation and leisure." : "resort",
    "RESORT: The image shows a large hotel complex, which is surrounded by a lush green lawn and a sandy beach. The hotel has a pool, which is visible in the image. The beach is also visible, with sand and water in the foreground." : "resort",
    "RESORT: There is a large, luxurious complex with multiple buildings, a lake, and a golf course. The resort is surrounded by a forest, which adds to its natural beauty and serene atmosphere." : "resort",
    "RIVER: It is a river with a sandy beach and some trees nearby. The river is flowing through a forest, and there is a small village or town visible in the background." : "river",
    "RIVER: The river is surronded by a grassy field. The image also shows a city or town in the background." : "river",
    "RIVER: The river is flowing through a lush, green forest, which is a common characteristic of tropical rainforests. The image shows the river's winding path, with sandy banks and a sandy bottom. The forest is dense and full of trees." : "river",
    "SCHOOL: It is a campus or university area. The image shows a large, open area with buildings, including a few tall buildings, and a grassy field. There are also trees scattered throughout the area." : "school",
    "SCHOOL: The school area is likely a campus or a neighborhood with multiple schools, as there are multiple buildings visible in the picture and various and green spaces visible." : "school",
    "SCHOOL: The image shows a school with a large field in the middle of it. The school is surrounded by buildings, and there is a stadium in the middle of the field. The image also shows a road and a parking lot." : "school",
    "SPARSE RESDENTIAL: There is only one house in the neighborhood surronded by some trees and green vegetation." : "sparse resdential",
    "SPARSE RESDENTIAL: The image shows a large area withone house situated in the middle of it." : "sparse resdential",
    "SPARSE RESDENTIAL: The image shows a bird's-eye view of the neighborhood, with only one house and a pool situated in the middle of the property." : "sparse resdential",
    "SQUARE: The image shows a large open space with a dirt circle in the center, surrounded by various construction equipment and vehicles." : "square",
    "SQUARE: The image shows a large open space with trees and a grassy area. The area is filled with trees, and there is a large grassy area in the center." : "square",
    "SQUARE: The view in this picture is a large, open park with a square in the center. The park is surrounded by buildings, and there are people enjoying the outdoor space." : "square",
    "STADIUM: The view in the image is a large, circular stadium with a grass field. The stadium is surrounded by a parking lot and a road, indicating that it is likely a sports facility." : "stadium",
    "STADIUM: The land cover in this picture is a large, open field or stadium, which is covered in green grass." : "stadium",
    "STADIUM: The image shows a large, grassy field with a baseball diamond in the center, surrounded by a stadium with seating areas for spectators." : "stadium",
    "STORAGE TANKS: The image shows a large industrial complex with multiple tanks and a railroad track running alongside it." : "storage tanks",
    "STORAGE TANKS: The view in this picture is an industrial area, specifically a petrochemical plant or refinery. The image shows a large number of large tanks, which are likely used for storing and processing various types of petroleum products."  : "storage tanks",
    "STORAGE TANKS: The view in this picture is industrial, with a large complex of buildings and structures, including oil refineries, pipelines, and storage tanks." : "storage tanks",
    "VIADUCT: The image shows a large intersection with multiple roads and highways." : "viaduct",
    "VIADUCT: The image shows a close-up view of a city intersection, with multiple roads and traffic lights." : "viaduct",
    "VIADUCT: The view in this picture is a highway, specifically a complex interchange with multiple roads and ramps. The image shows a large, busy highway with multiple lanes and interchanges." : "viaduct"
}

In [None]:
# Validation split: one sub-folder per class. Set the AID_VAL_DIR environment
# variable to point elsewhere instead of editing the author's local path.
val = os.environ.get("AID_VAL_DIR", 'F:/course/2023summer/AID_val')
batch_size = 1
test_set = MyDataset(root=val, transform=preprocess, json_path='aid.json')
# Evaluation metrics do not depend on sample order; shuffle=False keeps the
# iteration order reproducible across runs.
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Prompt templates, one per article variant; "{c}" receives each prompt text.
templates = [
    "an aerial photo of the {c}.",
    "an aerial photo of a {c}.",
    "an aerial photo of an {c}.",
    "an aerial photo of {c}.",
]

# {"0": name, "1": name, ...} — class index -> class name.
# NOTE(review): this reads 'aid_classes_indices.json', while the dataset cell
# writes 'aid.json'; confirm both list the classes in the same order.
with open("aid_classes_indices.json", "r") as class_file:
    json_data = json.load(class_file)

classnames = [name for name in json_data.values()]
print(classnames)

In [None]:
# Zero-shot evaluation of the prompt-expanded classifier on the AID split.
metrics = evaluate(
    model=model,
    dataloader=test_loader,
    tokenizer=tokenizer,
    classnames=classnames,
    templates=templates,
    prompts=map_to_parent,
    class_indices=json_data,
    device=device,
    amp=True,
    verbose=False,
    cupl=False,
    save_clf=None,
    load_clfs=[],
)
# Optional: persist the metrics (uncomment to write AID_results_hierarchy.json).
# dump = {
#     "dataset": "AID",
#     "model": "ViT-B-32",
#     "pretrained": "openai",
#     "metrics" : metrics,
# }
# with open("AID_results_hierarchy.json", "w") as f:
#         json.dump(dump, f)
print("done")