In [1]:
import os
import torch
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image
import timm
from torch import nn
from torchvision import transforms
from tqdm import tqdm
import pandas as pd

In [2]:
# Encoder 클래스 정의
class Encoder(nn.Module):
    def __init__(self, args, pretrained=False):
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name
        if model_name.startswith('resnet'):
            self.model_type = 'resnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features  # encoder_dim
            self.cnn.global_pool = nn.Identity()
            self.cnn.fc = nn.Identity()
        elif model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
                                                 use_checkpoint=args.use_checkpoint)
            self.n_features = self.transformer.num_features
            self.transformer.head = nn.Identity()
        elif 'efficientnet' in model_name:
            self.model_type = 'efficientnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            self.cnn.global_pool = nn.Identity()
            self.cnn.classifier = nn.Identity()
        else:
            raise NotImplemented

    def swin_forward(self, transformer, x):
        x = transformer.patch_embed(x)
        if transformer.absolute_pos_embed is not None:
            x = x + transformer.absolute_pos_embed
        x = transformer.pos_drop(x)

        def layer_forward(layer, x, hiddens):
            for blk in layer.blocks:
                if not torch.jit.is_scripting() and layer.use_checkpoint:
                    x = torch.utils.checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            H, W = layer.input_resolution
            B, L, C = x.shape
            hiddens.append(x.view(B, H, W, C))
            if layer.downsample is not None:
                x = layer.downsample(x)
            return x, hiddens

        hiddens = []
        for layer in transformer.layers:
            x, hiddens = layer_forward(layer, x, hiddens)
        x = transformer.norm(x)  # B L C
        hiddens[-1] = x.view_as(hiddens[-1])
        return x, hiddens

    def forward(self, x):
        if self.model_type in ['resnet', 'efficientnet']:
            features = self.cnn(x)
            features = features.permute(0, 2, 3, 1)
        elif self.model_type == 'swin':
            if 'patch' in self.model_name:
                features, _ = self.swin_forward(self.transformer, x)  # Adjust here if more outputs
            else:
                features, _ = self.transformer(x)  # Adjust here if more outputs
        else:
            raise NotImplemented
        return features  # Return only features

In [3]:
# 인코더 모델을 초기화하기 위한 args 설정
class Args:
    def __init__(self):
        self.encoder = 'swin_base_patch4_window7_224'  # 사용할 인코더 모델 이름
        self.use_checkpoint = False  # 스윈 트랜스포머에서 체크포인트 사용 여부

In [4]:
# 모델 로드
args = Args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(args, pretrained=False).to(device)

# 훈련된 모델 파라미터 로드
model_path = '../ckpt/swin_base_char_aux_1m680k.pth'  # 실제 모델 파일 경로로 바꾸세요
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint['encoder'] if 'encoder' in checkpoint else checkpoint
encoder.load_state_dict(state_dict, strict=False)
encoder.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Encoder(
  (transformer): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (0): BasicLayer(
        dim=128, input_resolution=(56, 56), depth=2
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=128, out_features=384, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=128, out_features=128, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
         

In [5]:
# 데이터 준비 및 전처리
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [6]:
def smiles_to_image(smiles, img_size=(256, 256)):
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Draw.MolToImage(mol, size=img_size)
    else:
        return None

In [7]:
train_df = pd.read_csv('../train_data/train_data.csv')
val_df = pd.read_csv('../train_data/val_data.csv')

print(train_df)
print(val_df)

                                                  Smiles  \
0            Nc1ccc2c(c1)/C(=C\c1ccc(-c3ccoc3)o1)C(=O)N2   
1      CCc1nc(C(N)=O)c(Nc2ccc(N3CCN(C)CC3)cc2)nc1N[C@...   
2      O=C(Nc1cccc(CNc2ncnc3c2cnn3CCc2ccccc2)c1)c1ccc...   
3      O=S(=O)(Nc1cccc(CNc2ncnc3c2cnn3CCc2ccccc2)c1)c...   
4      O=S(=O)(Nc1cccc(CNc2ncnc3c2cnn3CCc2ccccc2)c1)c...   
...                                                  ...   
90210  Cc1nc(-c2cccc(Nc3nc(C)c(-c4cc5c(c(S(C)(=O)=O)c...   
90211  Cc1nc(Nc2cccc(-c3nc(C)n(C)n3)n2)sc1-c1cc2c(c(S...   
90212  Cn1nccc1-n1c([C@@H]2CSCN2c2ncnc3[nH]cnc23)nc2c...   
90213  COc1cc(-c2cc3c(O[C@H]4CCN(C(=O)C5CCC5)C4)ncnc3...   
90214    COc1cccc2cc(-c3nn(C4CCOCC4)c4ncnc(N)c34)[nH]c12   

                                Target Name  Standard Value  \
0      Tyrosine-protein kinase receptor RET         34000.0   
1      Tyrosine-protein kinase receptor RET             1.1   
2      Tyrosine-protein kinase receptor RET          9310.0   
3      Tyrosine-protein kin

In [8]:
for df in [train_df, val_df]:
    feature_vectors = []
    for smiles in tqdm(df['Smiles']):
        img = smiles_to_image(smiles)
        if img is not None:
            img = img.convert('RGB')
            img_tensor = transform(img).unsqueeze(0).to(device)
            features = encoder(img_tensor).detach().cpu().numpy().flatten()
            feature_vectors.append(features)
        else:
            feature_vector.append(np.nan)
    df['image_feature_vector'] = feature_vectors

100%|██████████| 90215/90215 [16:42<00:00, 89.98it/s]
100%|██████████| 1952/1952 [00:21<00:00, 89.44it/s]


In [9]:
print(train_df.head())
print(val_df.head())

                                              Smiles  \
0        Nc1ccc2c(c1)/C(=C\c1ccc(-c3ccoc3)o1)C(=O)N2   
1  CCc1nc(C(N)=O)c(Nc2ccc(N3CCN(C)CC3)cc2)nc1N[C@...   
2  O=C(Nc1cccc(CNc2ncnc3c2cnn3CCc2ccccc2)c1)c1ccc...   
3  O=S(=O)(Nc1cccc(CNc2ncnc3c2cnn3CCc2ccccc2)c1)c...   
4  O=S(=O)(Nc1cccc(CNc2ncnc3c2cnn3CCc2ccccc2)c1)c...   

                            Target Name  Standard Value  \
0  Tyrosine-protein kinase receptor RET         34000.0   
1  Tyrosine-protein kinase receptor RET             1.1   
2  Tyrosine-protein kinase receptor RET          9310.0   
3  Tyrosine-protein kinase receptor RET          4830.0   
4  Tyrosine-protein kinase receptor RET         12850.0   

   Smiles_feature_vector                               image_feature_vector  \
0                    NaN  [-1.4555347, 0.60381854, -0.99908453, 1.393386...   
1                    NaN  [-1.4356725, 0.5544013, -0.9523898, 1.3807472,...   
2                    NaN  [-1.4010823, 0.5928811, -0.972553, 1.3960526,

In [10]:
train_df.to_csv('../train_data/train_data_img.csv')
val_df.to_csv('../train_data/val_data_img.csv')