In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from transformers import AutoTokenizer, AutoModel
from torchvision import transforms


import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import json
import itertools
from rdkit import Chem
from rdkit.Chem import AllChem
import tensorflow as tf




In [2]:
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; ``in_features`` is used when this is falsy.
        out_features: size of the last dimension of the output; ``in_features`` is used when this is falsy.
        act_layer: zero-argument callable that builds the activation module.
        drop: dropout probability; the same ``nn.Dropout`` instance is applied after each linear stage.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through both linear stages and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

# Attention layer
class Attention(nn.Module):
    """Multi-head self-attention that can optionally record its attention map and gradients.

    Args:
        dim: embedding size of the input tokens (last dimension of ``x``).
        num_heads: number of attention heads; each head works on ``dim // num_heads`` channels.
        qkv_bias: whether the fused query/key/value projection has a bias term.
        qk_scale: scaling applied to the query-key scores; when falsy, ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Populated only by forward(..., register_hook=True).
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention weights (used as a tensor hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the most recently stored attention gradients, or None."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the attention weights of the current forward pass."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the most recently stored attention weights, or None."""
        return self.attention_map

    def forward(self, x, register_hook=False):
        """Apply self-attention to ``x`` of shape (B, N, C) and return a tensor of the same shape.

        When ``register_hook`` is true the attention weights are saved on the module, and a
        gradient hook is attached to them if (and only if) they take part in autograd.
        """
        B, N, C = x.shape
        # (B, N, C) -> (B, N, 3, num_heads, nums_per_head) -> (3, B, num_heads, N, nums_per_head)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # each of q, k, v: (B, num_heads, N, nums_per_head)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # (B, num_heads, N, nums_per_head) @ (B, num_heads, nums_per_head, N) -> (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # softmax over the key axis
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if register_hook:
            self.save_attention_map(attn)
            # Tensor.register_hook raises on tensors that do not require grad (e.g. under
            # torch.no_grad()), so only attach the hook when a backward pass is possible.
            if attn.requires_grad:
                attn.register_hook(self.save_attn_gradients)
        # (B, num_heads, N, N) @ (B, num_heads, N, nums_per_head) -> (B, num_heads, N, nums_per_head)
        # -> (B, N, num_heads, nums_per_head) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # linear projection
        x = self.proj(x)
        # projection dropout
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Transformer encoder block: normalised attention and MLP sub-layers, each added back residually.

    Args:
        dim: token embedding size.
        num_heads: number of attention heads.
        mlp_ratio: the MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias, qk_scale, attn_drop: forwarded to ``Attention``.
        drop: dropout probability for the attention output projection and the MLP.
        drop_path: stochastic-depth rate; ``nn.Identity`` is used when it is not positive.
        act_layer: activation constructor for the MLP.
        norm_layer: constructor called with ``dim`` to build each normalisation layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x, register_hook=False):
        """Apply the block to ``x``; ``register_hook`` is passed through to the attention layer."""
        attn_branch = self.attn(self.norm1(x), register_hook=register_hook)
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)


# ViT (Vision Transformer)
class VisionTransformer(nn.Module):
    """ Vision Transformer encoder (no classification head).

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
        https://arxiv.org/abs/2010.11929

    ``forward`` returns the normalised token sequence (class token followed by patch tokens).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): accepted for signature compatibility; not read by this implementation
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): accepted for signature compatibility; not read by this
                implementation
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer; defaults to LayerNorm with eps=1e-6
        """
        super().__init__()
        # num_features mirrors embed_dim for consistency with other models
        self.embed_dim = embed_dim
        self.num_features = embed_dim
        if not norm_layer:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # one position per patch plus one for the class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: rate grows linearly from 0 to drop_path_rate over the blocks
        drop_path_schedule = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList()
        for block_rate in drop_path_schedule:
            self.blocks.append(
                Block(
                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                    qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=block_rate,
                    norm_layer=norm_layer)
            )
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear layers (truncated normal weight, zero bias) and LayerNorm (unit weight, zero bias)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        """Encode a batch of images.

        Args:
            x: image batch passed to the patch embedding.
            register_blk: index of the block whose attention map/gradients should be recorded;
                the default -1 matches no block.
        """
        batch_size = x.shape[0]
        patch_tokens = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed[:, :tokens.size(1), :])

        for idx, blk in enumerate(self.blocks):
            tokens = blk(tokens, register_blk == idx)

        return self.norm(tokens)



def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
    """Resize a checkpoint's position embedding to match the patch grid of ``visual_encoder``.

    Args:
        pos_embed_checkpoint: position embedding from a checkpoint, indexed as
            (batch, extra tokens + patch positions, embedding size).
        visual_encoder: object exposing ``patch_embed.num_patches`` and ``pos_embed``.

    Returns:
        ``pos_embed_checkpoint`` itself when the grid side lengths already match; otherwise a new
        tensor whose extra-token rows are copied unchanged and whose patch rows are bicubically
        interpolated to the encoder's grid. Both grids are treated as square (side = int(sqrt(n))).
    """
    embed_dim = pos_embed_checkpoint.shape[-1]
    target_patches = visual_encoder.patch_embed.num_patches
    # rows of the encoder's pos_embed that are not patch positions (e.g. the class token)
    n_extra = visual_encoder.pos_embed.shape[-2] - target_patches
    # side length of the checkpoint's patch grid and of the encoder's patch grid
    src_side = int((pos_embed_checkpoint.shape[-2] - n_extra) ** 0.5)
    dst_side = int(target_patches ** 0.5)

    if src_side == dst_side:
        return pos_embed_checkpoint

    # the extra tokens are kept as-is; only the patch positions are interpolated
    prefix_tokens = pos_embed_checkpoint[:, :n_extra]
    grid_tokens = pos_embed_checkpoint[:, n_extra:]
    # (b, h*w, c) -> (b, h, w, c) -> (b, c, h, w)
    grid_tokens = grid_tokens.reshape(-1, src_side, src_side, embed_dim).permute(0, 3, 1, 2)
    grid_tokens = F.interpolate(
        grid_tokens, size=(dst_side, dst_side), mode='bicubic', align_corners=False)
    # (b, c, h, w) -> (b, h, w, c) -> (b, h*w, c)
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    resized = torch.cat((prefix_tokens, grid_tokens), dim=1)
    print('reshape position embedding from %d to %d' % (src_side ** 2, dst_side ** 2))
    return resized


def createVisualModel(resolution_after=224):
    """Build a patch-16, 12-layer, 768-dim ViT and load the DeiT base checkpoint into it.

    Args:
        resolution_after: input image size the encoder is built for; the checkpoint's position
            embedding is interpolated to the matching patch grid when needed.

    Returns:
        The ``VisionTransformer`` with the checkpoint weights loaded (loaded onto CPU).
    """
    encoder_config = dict(
        img_size=resolution_after,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    visual_encoder = VisionTransformer(**encoder_config)

    checkpoint = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        map_location="cpu",
        check_hash=True,
    )
    weights = checkpoint["model"]
    weights['pos_embed'] = interpolate_pos_embed(weights['pos_embed'], visual_encoder)
    # strict=False: keys that do not line up between checkpoint and encoder are tolerated, not raised
    visual_encoder.load_state_dict(weights, strict=False)
    return visual_encoder

In [3]:
# Build the image encoder for 224x224 inputs; createVisualModel fetches the checkpoint by URL.
visual_encoder = createVisualModel(resolution_after=224)

In [4]:
import torch
from torchvision import transforms
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image

In [5]:
# Preprocessing applied to each rendered molecule image before it reaches the encoder.
transform = transforms.Compose(
    [
        # force a fixed 224x224 spatial size
        transforms.Resize(size=(224, 224)),
        # PIL image -> PyTorch tensor
        transforms.ToTensor(),
        # per-channel normalisation with ImageNet's parameters
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
def image_feature(smiles, visual_encoder):
    """Render a molecule as a 224x224 RGB image and encode it with ``visual_encoder``.

    Args:
        smiles: SMILES string of the molecule to draw.
        visual_encoder: callable applied to a (1, 3, 224, 224) float tensor. When it exposes
            parameters (an ``nn.Module``), the input is moved to the device of its first
            parameter so that model and input always agree.

    Returns:
        The encoder output as a NumPy array, with the leading batch dimension of 1 kept.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals a parse failure by returning None; fail with the offending input.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    molecule_image = Draw.MolToImage(molecule, size=(224, 224)).convert('RGB')
    batch = transform(molecule_image).unsqueeze(0)

    # Follow the encoder's own device rather than assuming it was moved to CUDA: an encoder
    # left on CPU combined with a CUDA input raises a device-mismatch error.
    try:
        device = next(visual_encoder.parameters()).device
    except (AttributeError, StopIteration):
        # No parameters to inspect: keep the previous heuristic.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = batch.to(device)

    # Inference only: no gradients are needed for feature extraction.
    with torch.no_grad():
        features = visual_encoder(batch)

    return features.cpu().detach().numpy()

In [7]:
from transformers import AutoTokenizer, AutoModel

# Tokenizer and model are loaded from the same pretrained checkpoint id so they stay in sync.
chemberta_checkpoint = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(chemberta_checkpoint)
model = AutoModel.from_pretrained(chemberta_checkpoint)

In [8]:
def string_feature(smiles):
    """Embed a SMILES string (or a list of them) with the module-level ``tokenizer`` and ``model``.

    Args:
        smiles: text passed to the tokenizer, with padding and truncation enabled.

    Returns:
        A tensor holding the last hidden state at the first token position of each sequence,
        i.e. one row per input sequence. The tensor carries no autograd graph.
    """
    inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    # Feature extraction only: skip autograd bookkeeping, matching image_feature.
    with torch.no_grad():
        outputs = model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding

In [10]:
# Example SMILES string used below to smoke-test both feature extractors.
a = 'C/C(=C\\c1csc(C)n1)[C@@H]1C[C@@H]2O[C@]2(C)CCC[C@H](C)[C@H](O)[C@@H](C)C(=O)C(C)(C)[C@@H](O)CC(=O)N1'

In [15]:
# Sanity check: shape of the image features for the example molecule.
image_feature(a, visual_encoder).shape

(1, 197, 768)

In [14]:
# Sanity check: shape of the string embedding for the example molecule.
string_feature(a).shape

torch.Size([1, 768])

In [None]:
# NOTE(review): merged_df is not defined in the cells shown here; it must exist (with a SMILES
# column) before this cell runs.
# One (197, 768) feature matrix per row of merged_df, held in a float64 buffer (np.zeros default).
n_molecules = merged_df.shape[0]
image_features = np.zeros((n_molecules, 197, 768))

for row, smiles_str in enumerate(merged_df.SMILES):
    encoded = image_feature(smiles_str, visual_encoder)
    # drop the leading batch dimension of 1 before storing
    image_features[row] = encoded.squeeze(0)

import h5py

# Persist the full feature array as a single HDF5 dataset.
with h5py.File('test_image_features.h5', 'w') as hf:
    hf.create_dataset('test_image_features', data=image_features)

In [None]:
# Allocated as (number of rows in merged_df, 197, 768).
image_features.shape

In [None]:
# NOTE(review): merged_df is not defined in the cells shown here; it must exist (with a SMILES
# column) before this cell runs.
# One 768-dim embedding per row of merged_df, collected in a float64 buffer so the frame is
# created with its final dtype (no int-to-float upcast on row assignment).
n_molecules = merged_df.shape[0]
string_feature_matrix = np.zeros((n_molecules, 768), dtype=np.float64)

# Iterate positionally, like the image-feature loop: merged_df.SMILES[i] is a label lookup and
# fails or misaligns when the index is not 0..n-1.
for row, smiles_str in enumerate(merged_df.SMILES):
    embedding = string_feature(smiles_str)
    string_feature_matrix[row] = embedding.squeeze(0).detach().numpy()

# Default integer index and columns 0..767, as before.
string_features = pd.DataFrame(string_feature_matrix)

string_features.to_csv('test_string_features.csv', index=False)

In [None]:
# One row per row of merged_df, 768 columns.
string_features.shape