In [74]:
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import dgl
from dgl.nn import GATConv
import networkx as nx
import os, sys

In [75]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Cosine-similarity cut-off: GATClassifier.build_graph connects two samples
# whose similarity is strictly greater than this value.
THRESHOLD = 0.75

# Scalar mean / standard deviation that predict() uses to normalize pixels.
# NOTE(review): presumably dataset statistics on the 0-255 scale -- confirm.
IMAGE_MEAN = 155.5673
IMAGE_STD = 70.5983

In [76]:
class GATClassifier(nn.Module):
    """ResNet-50 feature extractor followed by a two-layer graph-attention head.

    Every sample in a batch becomes one node of a graph whose edges link
    samples with cosine similarity above the module-level ``THRESHOLD``.
    Two ``GATConv`` layers run on that graph; their output is concatenated
    with the CNN features and fed to a linear layer and a classifier.
    A second classifier works on the CNN features alone.

    Args:
        in_channels: Stored on the instance; not used by the visible code.
        emb_dims: Width the flattened CNN output is reshaped to in
            ``forward``. The layer sizes below are hard-coded to 2048, so
            this has to be 2048 for the shapes to line up.
        in_dims: Stored on the instance; not used by the visible code.
        out_dims: Added to 2048 to size the input of ``lin1``. The second
            GAT layer emits 512 features, so this has to be 512.
        num_classes: Number of output logits of both classifiers.
    """

    def __init__(self, in_channels, emb_dims, in_dims, out_dims, num_classes):
        super().__init__()

        self.in_channels = in_channels
        self.emb_dims = emb_dims
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.num_classes = num_classes

        self.adj_out_dims = 256

        # ResNet-50 with its last child module (the final fully connected
        # layer) removed, so the network ends at the pooled features.
        self.cnn = torch.nn.Sequential(
            *(list(models.resnet50(pretrained=True).children())[:-1])
        )

        self.gatconv1 = GATConv(2048, 1024, num_heads=1)
        self.gatconv2 = GATConv(1024, 512, num_heads=2)

        self.lin1 = nn.Linear(2048 + out_dims, 1024)
        self.class1 = nn.Linear(1024, num_classes)
        self.class2 = nn.Linear(2048, num_classes)

    def forward(self, x, get_embeddings=False):
        """Run the CNN and the graph head on a batch.

        Args:
            x: Input batch passed straight to the ResNet trunk.
            get_embeddings: When true, return the activation of ``lin1``
                instead of the class logits.

        Returns:
            The ``lin1`` activation with 1024 features per sample when
            ``get_embeddings`` is true; otherwise the tuple
            ``(logits1, logits2)`` where ``logits1`` comes from the combined
            CNN + graph features and ``logits2`` from the CNN features alone.
        """
        embeddings = self.cnn(x)  # bs, deep_features_h
        deep_features = embeddings.reshape(-1, self.emb_dims)
        dfs = deep_features.clone()
        g = self.build_graph(dfs)

        # The first GAT layer has a single head; the reshape drops that axis.
        x = F.relu(self.gatconv1(g, dfs).reshape(-1, 1024))
        # The second GAT layer has two heads; they are summed over the head axis.
        x = self.gatconv2(g, x).sum(dim=1)

        x = torch.cat((deep_features, x), dim=1)
        x = F.leaky_relu(self.lin1(x))

        if get_embeddings:
            return x

        logits1 = self.class1(x)
        logits2 = self.class2(deep_features)

        return logits1, logits2

    def build_graph(self, deep_features):
        """Build the batch similarity graph (one node per row).

        Two rows are connected when their cosine similarity is strictly
        greater than ``THRESHOLD``. The features are detached first, so no
        gradient flows through the graph construction.

        Args:
            deep_features: 2-D tensor with one feature row per sample.

        Returns:
            A DGL graph on the same device as ``deep_features``, with
            ``dgl.add_self_loop`` applied.
        """
        dfs = deep_features.clone().detach()
        z_norm = torch.linalg.norm(dfs, dim=1, keepdim=True)  # Size (n, 1).
        similarity = ((dfs @ dfs.T) / (z_norm @ z_norm.T)).T

        adjacency = torch.where(similarity > THRESHOLD, 1, 0)

        # networkx needs a NumPy array, hence the round trip through the CPU.
        gx = nx.from_numpy_array(adjacency.cpu().numpy())
        # Put the graph on the device of the features it was built from rather
        # than on a notebook-level global: forward() passes this graph and
        # those features to the same GATConv layer.
        g = dgl.from_networkx(gx).to(deep_features.device)
        g = dgl.add_self_loop(g)
        return g

    def save_model(self, path):
        """Write the model's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        """Load a ``state_dict`` saved by ``save_model`` into this model.

        Tensors are mapped onto the device the model currently lives on, so
        a checkpoint written on a GPU machine also loads on a CPU-only one.

        Args:
            path: File previously written by ``save_model``.
        """
        target_device = next(self.parameters()).device
        state = torch.load(path, map_location=target_device)
        self.load_state_dict(state)

In [113]:
# emb_dims is the width of the deep (CNN) features and out_dims the width of
# the graph features.
model = GATClassifier(in_channels=3, emb_dims=2048, in_dims=256, out_dims=512, num_classes=7)
model = model.to(device)

In [114]:
# Raw string literal: the Windows path is full of backslashes, and sequences
# such as "\T", "\G" or "\m" are invalid escapes in a normal string literal
# (a warning on current Python versions). The resulting path is unchanged.
model.load_model(r'D:\Tesis\GNN\GNN implementations\CGAT\src\model params\model_params_5')
# Inference only: switch the model (including the BatchNorm layers of the
# ResNet trunk) to evaluation mode.
model.eval()

GATClassifier(
  (cnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64

In [115]:
# Class index -> cloud-type name; predict() looks up the model's top-k
# indices in this mapping.
LABELS = dict(enumerate((
    'cumulus',
    'altocumulus',
    'cirrus',
    'clearsky',
    'stratocumulus',
    'cumulonimbus',
    'mixed',
)))

In [116]:
# Toy class indices (a) and scores (b) for trying out the label -> score mapping.
a, b = [1, 0], [0.003, 0.9]

In [117]:
# Pair each toy index's label with its score.
dict(zip((LABELS[i] for i in a), b))

{'altocumulus': 0.003, 'cumulus': 0.9}

In [118]:
def predict(img):
    """Classify one image and return label -> probability for the top classes.

    Args:
        img: NumPy image array with the channel axis last (the Gradio input
            in this notebook is configured with ``type="numpy"``).

    Returns:
        dict mapping class names from ``LABELS`` to softmax probabilities
        for the (at most) five most likely classes, most likely first.

    Raises:
        ValueError: If ``img`` is None (no image was supplied).
    """
    if img is None:
        raise ValueError("predict() needs an image array, got None")

    # Normalize with the scalar mean/std from the config cell.
    normalized = (img - IMAGE_MEAN) / IMAGE_STD
    batch = torch.from_numpy(normalized).float()
    # NOTE(review): permute(2, 1, 0) turns (H, W, C) into (C, W, H), i.e. the
    # image is also transposed. Left unchanged because it must match whatever
    # the training pipeline did -- confirm against the training code.
    batch = batch.permute(2, 1, 0).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns two sets of logits; only the first is used here.
        logits, _ = model(batch)
        probs = F.softmax(logits, dim=-1)[0]  # single image: drop the batch axis

        # Never request more classes than the model produces.
        k = min(5, probs.shape[-1])
        values, indices = torch.topk(probs, k)

    confidences = {
        LABELS[index]: value
        for index, value in zip(indices.tolist(), values.tolist())
    }
    return confidences

In [119]:
# Upload widget feeds a NumPy RGB image to predict(); the returned
# label -> confidence dict is rendered by Gradio's 'label' output.
interface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(
        shape=(256, 256),
        image_mode="RGB",
        invert_colors=False,
        source="upload",
        tool="editor",
        type="numpy",
        label=None,
        optional=False,
    ),
    outputs='label',
)
interface.launch()

Running on local URL:  http://127.0.0.1:7885/

To create a public link, set `share=True` in `launch()`.


(<fastapi.applications.FastAPI at 0x1ef5dfff640>,
 'http://127.0.0.1:7885/',
 None)