In [1]:
import torch
import torch.nn as nn
import contextlib
from transformers import AutoTokenizer, OPTForCausalLM, OPTConfig
from models.Qformer import *
from utils.data import build_dataloader_from_yaml
import torch.nn.functional as F
cur_device = torch.device("cuda:6")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MLP(nn.Module):

    def __init__(
            self,
            num_query_token=32,
            clip_vec_len=1408,
            sam_vec_len=4096,
            cls_num = 14
        ):
        super().__init__()
        self.clip_qformer = Qformer(fecture_vec_len=clip_vec_len, num_query_token=num_query_token+1, cross_attention_freq=2)
        self.sam_qformer = Qformer(fecture_vec_len=sam_vec_len, num_query_token=num_query_token+1, cross_attention_freq=2)
        self.cls_num = cls_num
        
        self.fc1 = nn.Linear(1536, 100)  # 输入大小为1536，输出大小为100
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 42)    # 输入大小为100，输出大小为10
        self.sigmoid = nn.Sigmoid()
        self.lossfn = nn.BCELoss()
        self.target = torch.rand((4,42,))

    def forward(self, samples):
        clip_features = samples['clip_features']
        sam_features = samples['sam_features']
        # text_input = samples['text_input']
        cur_device = clip_features.device
        # print("clip shape: ", clip_features.shape, "sam shape: ", sam_features.shape)

        # CLIP
        clip_query_output = self.clip_qformer(
            features=clip_features,
            attention_mask=torch.ones(clip_features.size()[:-1], dtype=torch.long).to(cur_device)
        )

        # SAM
        sam_features = sam_features.to(cur_device)
        sam_query_output = self.sam_qformer(
            features=sam_features,
            attention_mask=torch.ones(sam_features.size()[:-1], dtype=torch.long).to(cur_device),
        )

        clip_cls = sam_query_output[:, -1, :]
        sam_cls = clip_query_output[:, -1, :]

        cat_cls = torch.cat([clip_cls, sam_cls], dim=1)
        print(cat_cls.shape)
        x = self.fc1(cat_cls)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1, 14, 3)
        x = F.softmax(x, dim=2)
        x = x.view(-1, 42)
        print(x)
        loss = self.lossfn(x, self.target)
        return loss


In [3]:
class simplenet(nn.Module):
    def __init__(
            self,
        ):
        super().__init__()
        
        self.fc1 = nn.Linear(1536, 100)  # 输入大小为1536，输出大小为100
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 42)    # 输入大小为100，输出大小为10
        self.sigmoid = nn.Sigmoid()
        self.lossfn = nn.CrossEntropyLoss()
        self.target = torch.rand((42*4,))
    
    def forward(self, cat_cls):
        # text_input = samples['text_input']
        # print("clip shape: ", clip_features.shape, "sam shape: ", sam_features.shape)

        # CLIP

        
        # MLP
        x = self.fc1(cat_cls)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1, 14, 3)
        x = F.softmax(x, dim=2)
        x = x.view(-1, 42)
        print(x)
        # loss = self.lossfn(x, self.target)
        # return loss
    
    
    

In [4]:
anet = simplenet()

In [5]:
def forward(mlp_model, samples, simplenet):
    clip_features = samples['clip_features']
    sam_features = samples['sam_features']
    # text_input = samples['text_input']
    cur_device = clip_features.device
    # print("clip shape: ", clip_features.shape, "sam shape: ", sam_features.shape)

    # CLIP
    clip_query_output = mlp_model.clip_qformer(
        features=clip_features,
        attention_mask=torch.ones(clip_features.size()[:-1], dtype=torch.long).to(cur_device)
    )

    # SAM
    sam_features = sam_features.to(cur_device)
    sam_query_output = mlp_model.sam_qformer(
        features=sam_features,
        attention_mask=torch.ones(sam_features.size()[:-1], dtype=torch.long).to(cur_device),
    )
    print(sam_query_output.shape)
    clip_cls = sam_query_output[:, -1, :]
    sam_cls = clip_query_output[:, -1, :]
    print(clip_cls.shape, sam_cls.shape)
    cat_cls = torch.cat([clip_cls, sam_cls], dim=1)
    print(cat_cls.shape)
    simplenet(cat_cls)

    

In [6]:
configpath = "/home/xcg/medical-research/Project23us/config/train.yaml"
custom_dataloader = build_dataloader_from_yaml(configpath)

In [7]:
for clip_feature, sam_feature, caption in custom_dataloader:
#     clip_feature = clip_feature.to(cur_device)
#     sam_feature = sam_feature.to(cur_device)
    clip_shape = clip_feature.shape
    sam_shape = sam_feature.shape
#     print(clip_feature.shape, sam_feature.shape)
    my_samples = {
            'sam_features': sam_feature.view(sam_shape[0], sam_shape[1]*sam_shape[2], sam_shape[3]),
            'clip_features': clip_feature.view(clip_shape[0], clip_shape[1]*clip_shape[2], clip_shape[3]),
            'text_input': caption,
    }
    
    loss = mlp_model(my_samples)
    break


NameError: name 'mlp_model' is not defined

In [None]:
loss

tensor(0.7561, grad_fn=<BinaryCrossEntropyBackward0>)

In [None]:
from utils.runner import *
from utils.data import build_dataloader
from tqdm import tqdm

class MLPRunner(RunnerBase):
    def __init__(
        self,
        model,
        cfg,
    ):
        config = self.build_config(cfg)
        optimizer = self.build_optimizer(model, config)
        dataloader = build_dataloader(config)
        max_epoch = config["run"]["max_epoch"]
        device = config["run"]["device"]
        device = torch.device('cpu')
        super().__init__(
            model=model,
            optimizer=optimizer,
            dataloader=dataloader,
            max_epoch=max_epoch,
            device=device,
        )
        self.config = config
    
    def train_step(self, samples):
        clip_shape = samples[0].shape
        sam_shape = samples[1].shape
        
        # print(samples[1].view(clip_shape[0], clip_shape[1]*clip_shape[2], clip_shape[3]).shape)
        my_samples = {
            'sam_features': samples[1].view(sam_shape[0], sam_shape[1]*sam_shape[2], sam_shape[3]).to(self.device),
            'clip_features': samples[0].view(clip_shape[0], clip_shape[1]*clip_shape[2], clip_shape[3]).to(self.device),
            'text_input': samples[2],
        }
        # print(my_samples)
        loss = self.model(my_samples)
        return loss

    def train_epoch(self):
        for samples in tqdm(self.dataloader):
            with torch.cuda.amp.autocast(enabled=True):
                loss = self.train_step(samples)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

    def epoch_start_hook(self, info):
        pass

    def epoch_end_hook(self, info):
        # also save MLP
        torch.save({
            'epoch': info['cur_epoch'],  # 假设你训练了5个epochs
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, f"./checkpoints/mlp_checkpoint_{info['cur_epoch']}.pth")
        print(info)

    def build_config(self, cfg):
        with open(cfg, 'r') as file:
            _config = yaml.load(file, Loader=yaml.FullLoader)
        return _config
    
    @classmethod
    def build_optimizer(self, model, config):
        # lr_scale = config["run"]["lr_layer_decay"]
        # weight_decay = config["run"]["weight_decay"]
        optim_params = model.parameters()
        # optim_params = self.model.Parameters()

        # num_parameters = 0
        # for p_group in optim_params:
        #     for p in p_group["params"]:
        #         num_parameters += p.data.nelement()    
        # logging.info("number of trainable parameters: {}".format(num_parameters))      
                
        beta2 = config["run"]["beta2"]

        _optimizer = torch.optim.AdamW(
            optim_params,
            lr=float(config["run"]["init_lr"]),
            betas=(0.9, beta2),
        )    
        return _optimizer

In [None]:
runner = MLPRunner(mlp_model, "/home/xcg/medical-research/Project23us/config/train.yaml")



In [None]:
runner.train_epoch()