<a href="https://colab.research.google.com/github/SauravMaheshkar/ResMLP-Flax/blob/main/notebooks/Gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# 🏗 Basic Setup


---





In [1]:
%%capture
!pip install --upgrade gradio
!pip install timm -upgrade
!pip install git+https://github.com/rwightman/pytorch-image-models.git
!pip install submitit
!pip install --upgrade wandb

In [None]:
# Run the Weights & Biases CLI login command from the notebook shell.
!wandb login

In [2]:
import torch
import requests
import torchvision
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.models.layers import trunc_normal_,  DropPath
from timm.models.vision_transformer import Mlp, PatchEmbed

# Turn autograd off globally; the trailing `;` suppresses the cell's output repr.
torch.set_grad_enabled(False);


# 🏠 Model Architecture
---



In [4]:
class Affine(nn.Module):
    """Learnable element-wise scale and shift: ``alpha * x + beta``.

    Both parameters have shape ``(dim,)`` and broadcast against the input.
    ``alpha`` starts at one and ``beta`` at zero, so a freshly built layer
    returns its input unchanged.
    """

    def __init__(self, dim):
        """Create the scale (``alpha``) and shift (``beta``) parameters of size ``dim``."""
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return ``x`` scaled by ``alpha`` and shifted by ``beta``."""
        scaled = self.alpha * x
        return scaled + self.beta
    
class layers_scale_mlp_blocks(nn.Module):
    """Residual block with a linear layer across patches followed by an MLP.

    Each of the two residual branches is preceded by an ``Affine`` layer,
    multiplied by a learnable scale vector (``gamma_1`` / ``gamma_2``), and
    passed through drop-path before being added back to the input.

    Args:
        dim: size of the last input dimension.
        drop: dropout value forwarded to ``Mlp``.
        drop_path: drop-path rate; ``nn.Identity`` is used when it is not positive.
        act_layer: activation class forwarded to ``Mlp``.
        init_values: initial value of every entry of ``gamma_1`` and ``gamma_2``.
        num_patches: size of the patch dimension mixed by ``attn``.
    """

    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU,init_values=1e-4,num_patches = 196):
        super().__init__()
        self.norm1 = Affine(dim)
        # Despite its name, `attn` is a plain linear layer over the patch dimension.
        self.attn = nn.Linear(num_patches, num_patches)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = Affine(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(4.0 * dim),
            act_layer=act_layer,
            drop=drop,
        )
        # Learnable scales applied to each residual branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        """Apply the patch-mixing branch, then the MLP branch; expects ``(batch, num_patches, dim)``."""
        # Swap dims 1 and 2 so the linear layer acts on the patch dimension, then swap back.
        patches_last = self.norm1(x).transpose(1, 2)
        patch_mixed = self.attn(patches_last).transpose(1, 2)
        x = x + self.drop_path(self.gamma_1 * patch_mixed)

        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma_2 * mlp_out)
        return x


class resmlp_models(nn.Module):
    """ResMLP image classifier: patch embedding, a stack of residual blocks, and a linear head.

    Args:
        img_size: forwarded to ``Patch_layer``.
        patch_size: forwarded to ``Patch_layer``.
        in_chans: forwarded to ``Patch_layer`` after ``int()`` conversion.
        num_classes: output width of the head; ``nn.Identity`` is used when it is not positive.
        embed_dim: feature width shared by the patch embedding, the blocks and the head.
        depth: number of ``layers_scale_mlp_blocks`` in the stack.
        drop_rate: dropout value handed to every block.
        Patch_layer: callable that builds the patch-embedding module; the result
            must expose a ``num_patches`` attribute.
        act_layer: activation class handed to every block.
        drop_path_rate: drop-path rate; the same value is used for every block.
        init_scale: initial value of each block's residual scale parameters.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,drop_rate=0.,
                 Patch_layer=PatchEmbed,act_layer=nn.GELU,
                drop_path_rate=0.0,init_scale=1e-4):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Every block receives the same drop-path rate (no per-depth schedule).
        self.blocks = nn.ModuleList([
            layers_scale_mlp_blocks(
                dim=embed_dim, drop=drop_rate, drop_path=drop_path_rate,
                act_layer=act_layer, init_values=init_scale,
                num_patches=num_patches)
            for _ in range(depth)])

        self.norm = Affine(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # Re-initialise all Linear / LayerNorm submodules created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule: truncated-normal Linear weights, zero biases, unit LayerNorm weights."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a new one sized for ``num_classes``.

        ``global_pool`` is accepted but not used by this implementation.
        The new head keeps ``nn.Linear``'s default initialisation.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed the image into patches, run every block, and average over dim 1.

        Returns a tensor with the patch dimension removed, i.e. one feature
        vector per batch element.
        """
        x = self.patch_embed(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        # Average over the patch dimension. This replaces the former
        # reshape(B, 1, -1)[:, 0] round-trip, which produced the same tensor
        # but could not infer the -1 dimension for an empty batch.
        return x.mean(dim=1)

    def forward(self, x):
        """Return the head's output for a batch of images."""
        x = self.forward_features(x)
        x = self.head(x)
        return x


# Build the network used by the demo: 8x8 patches, 24 blocks, width 768.
model = resmlp_models(
    patch_size=8,
    embed_dim=768,
    depth=24,
    Patch_layer=PatchEmbed,
    init_scale=1e-6,
)

In [5]:
# Fetch the pretrained state dict (mapped onto the CPU) and load it into the model.
checkpoint = torch.hub.load_state_dict_from_url(
    url="https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth",
    map_location="cpu",
    check_hash=True,
)

model.load_state_dict(checkpoint)
# Switch to evaluation mode; the trailing `;` hides the module repr in the cell output.
model.eval();

In [None]:
# Per-channel normalisation statistics (presumably the usual ImageNet RGB values; confirm).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
transformations = {}

# Resize so that the 224-pixel centre crop keeps 90% of the resized image.
Rs_size = int(224 / 0.9)

# NOTE(review): interpolation=3 looks like Pillow's integer code for bicubic; confirm.
transform = transforms.Compose([
    transforms.Resize(Rs_size, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

In [13]:
# Download human-readable labels for ImageNet.
# A timeout keeps the cell from hanging forever, and raise_for_status() stops
# an HTTP error page from silently becoming the label list.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
# splitlines() avoids the empty trailing entry that split("\n") leaves after a final newline.
labels = response.text.splitlines()

def predict(inp):
    """Classify one image and return a ``{label: probability}`` mapping.

    Args:
        inp: image array; it is cast to ``uint8`` and interpreted as RGB.
            NOTE(review): presumably shaped (H, W, 3) as delivered by the
            Gradio image input - confirm.

    Returns:
        dict mapping ``labels[i]`` to the softmax probability of class ``i``,
        with one entry per model output (no longer fixed at 1000).

    Raises:
        IndexError: if ``labels`` has fewer entries than the model has outputs.
    """
    inp = Image.fromarray(inp.astype('uint8'), 'RGB')
    # Preprocess and add the batch dimension the model expects.
    inp = transform(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp), dim=-1)[0]
    # One bulk conversion to Python floats instead of one tensor read per class.
    probabilities = prediction.tolist()
    return {labels[i]: prob for i, prob in enumerate(probabilities)}

# Gradio Demo

---



In [None]:
# Start a Weights & Biases run; the handle is kept so the run can be finished later.
import wandb
run = wandb.init(project='ResMLP', entity='sauravmaheshkar')

In [21]:
import gradio as gr

title = "ResMLP-B24/8 Demo"
article = "<p style='text-align: center'><a href='https://wandb.ai/sauravmaheshkar/ResMLP/reports/Feedforward-networks-for-image-classification--Vmlldzo4NTk3MDA'>Weights and Biases Report: Feedforward networks for image classification</a></p>"

# Current Gradio releases expose components at the top level (`gr.Image`,
# `gr.Label`) and no longer ship the `gr.inputs` / `gr.outputs` namespaces.
# The install cell always upgrades Gradio, so prefer the top-level API and
# fall back to the legacy namespaces only on old installs.
if hasattr(gr, "Image") and hasattr(gr, "Label"):
    inputs = gr.Image()
    outputs = gr.Label(num_top_classes=3)
else:
    inputs = gr.inputs.Image()
    outputs = gr.outputs.Label(num_top_classes=3)

# Wrap `predict` in a web UI and expose it through a temporary public share link.
io = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title=title, article=article)
io.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://37223.gradio.app
Interface loading below...


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7866/',
 'https://37223.gradio.app')

In [22]:
# Hand the wandb module to the Gradio interface's integrate() hook, then close the run.
io.integrate(wandb=wandb)
run.finish()