In [1]:
import torch
from utils.data import build_dataloader
import yaml
import importlib
import utils
from utils.SamBlipRunner import SamBlipRunner
from models.SamBlip import *
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build SamBlip and restore the pretrained weights. The state dict is loaded onto the CPU
# first (map_location="cpu"), and the whole model is then moved to the GPU.
# Training is started in a later cell, once `runner` has been constructed.
model = SamBlip()
device = torch.device("cuda:0")
model.load_state_dict(torch.load("sam_blip_pretrained.pth", map_location="cpu"))
model = model.to(device)

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.intermediate_query.dense.weight', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.6.output_query.LayerNorm.bias', 'bert.encoder.layer.9.output_query.dense.weight', 'bert.encoder.layer.6.output_query.dense.weight', 'bert.encoder.layer.8.output_query.dense.bias', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.5.intermediate_query.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.output_query.dense.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.11.output_query.dense.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', '

In [5]:
from utils.runner import *
from utils.data import build_dataloader
from tqdm import tqdm

class SamBlipRunner2(RunnerBase):
    """Training runner for SamBlip, configured from a YAML file.

    Builds the optimizer and dataloader from the config and passes them to
    ``RunnerBase`` together with ``max_epoch`` and ``device`` from ``config["run"]``.
    The methods below use ``self.model``, ``self.optimizer``, ``self.dataloader``,
    ``self.device`` and ``self.scaler``. These are presumably set up by
    ``RunnerBase.__init__`` (``self.scaler`` looks like an AMP GradScaler); TODO confirm.
    """

    def __init__(
        self,
        model,
        cfg,
    ):
        """Create the runner.

        Args:
            model: the SamBlip model; must provide
                ``get_optimizer_params(weight_decay, lr_scale)``.
            cfg: path to the YAML training config.
        """
        config = self.build_config(cfg)
        optimizer = self.build_optimizer(model, config)
        dataloader = build_dataloader(config)
        run_cfg = config["run"]
        super().__init__(
            model=model,
            optimizer=optimizer,
            dataloader=dataloader,
            max_epoch=run_cfg["max_epoch"],
            device=run_cfg["device"],
        )
        self.config = config

    @staticmethod
    def _merge_middle_dims(features):
        """Reshape a rank-4 tensor ``(b, d1, d2, c)`` into ``(b, d1 * d2, c)``.

        Unpacking ``.shape`` enforces rank 4. ``reshape`` (unlike ``view``) also
        accepts non-contiguous input.
        """
        b, d1, d2, c = features.shape
        return features.reshape(b, d1 * d2, c)

    def train_step(self, samples):
        """Run one forward pass on a batch and return the model's loss.

        ``samples`` is indexed as ``(clip_features, sam_features, text_input)``.
        Both feature tensors have their dims 1 and 2 merged into one axis, and are
        moved to ``self.device`` before being passed to the model as a dict.
        """
        clip_features = samples[0]
        sam_features = samples[1]
        model_inputs = {
            'sam_features': self._merge_middle_dims(sam_features).to(self.device),
            'clip_features': self._merge_middle_dims(clip_features).to(self.device),
            'text_input': samples[2],
        }
        output = self.model(model_inputs)
        return output['loss']

    def train_epoch(self):
        """Run one pass over ``self.dataloader`` with mixed-precision training.

        The forward pass runs under CUDA autocast. The loss is scaled by
        ``self.scaler`` for backward, and the scaler then drives the optimizer step.
        """
        for samples in tqdm(self.dataloader):
            # NOTE: newer PyTorch deprecates torch.cuda.amp.autocast in favour of
            # torch.amp.autocast("cuda"); kept as-is for compatibility.
            with torch.cuda.amp.autocast(enabled=True):
                loss = self.train_step(samples)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

    def epoch_start_hook(self, info):
        """No per-epoch setup is needed."""
        pass

    def epoch_end_hook(self, info):
        """Save a checkpoint for the finished epoch, then print ``info``.

        Writes ``./checkpoints/sam_blip_checkpoint_{cur_epoch}.pth`` containing the
        epoch index and the model and optimizer state dicts. The checkpoint
        directory is created if it does not exist, so the first save cannot fail on it.
        """
        import os

        checkpoint_dir = "./checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        cur_epoch = info['cur_epoch']
        torch.save({
            'epoch': cur_epoch,  # index of the epoch that just finished
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, f"{checkpoint_dir}/sam_blip_checkpoint_{cur_epoch}.pth")
        print(info)

    def build_config(self, cfg):
        """Parse the YAML file at path ``cfg`` and return the resulting config."""
        # YAML files are UTF-8; do not depend on the platform's locale encoding.
        with open(cfg, 'r', encoding='utf-8') as file:
            return yaml.load(file, Loader=yaml.FullLoader)

    @classmethod
    def build_optimizer(cls, model, config):
        """Build an AdamW optimizer over the parameter groups the model provides.

        Reads ``weight_decay``, ``lr_layer_decay``, ``beta2`` and ``init_lr`` from
        ``config["run"]``, and logs the number of parameters in the groups.

        Returns:
            torch.optim.AdamW over ``model.get_optimizer_params(weight_decay, lr_scale)``.
        """
        import logging

        run_cfg = config["run"]
        lr_scale = run_cfg["lr_layer_decay"]
        weight_decay = run_cfg["weight_decay"]
        optim_params = model.get_optimizer_params(weight_decay, lr_scale)

        num_parameters = sum(
            p.data.nelement() for p_group in optim_params for p in p_group["params"]
        )
        logging.info("number of trainable parameters: {}".format(num_parameters))

        return torch.optim.AdamW(
            optim_params,
            lr=float(run_cfg["init_lr"]),
            betas=(0.9, run_cfg["beta2"]),
        )

In [6]:

# YAML training config read by SamBlipRunner2.build_config (edit to match your checkout).
TRAIN_CONFIG_PATH = "/home/xcg/medical-research/Project23us/config/train.yaml"
runner = SamBlipRunner2(model, TRAIN_CONFIG_PATH)

In [7]:
# Launch training. RunnerBase.train presumably runs max_epoch epochs via
# SamBlipRunner2.train_epoch and calls epoch_end_hook (which saves a checkpoint)
# after each one; TODO confirm against utils.runner.
runner.train()

100%|██████████| 804/804 [28:51<00:00,  2.15s/it]


{'cur_epoch': 0}


  1%|          | 7/804 [00:08<16:21,  1.23s/it]


In [9]:
import numpy as np

device = torch.device("cuda:0")

# Root of the precomputed feature files and the id of the example used for generation below.
FEATURES_DIR = "/data2/xcg_data/lavis_data/2023us/features"
SAMPLE_ID = "20230101_11548959_012117568"


def load_npz_feature(path, target_device):
    """Load the ``arr`` entry of an ``.npz`` file as a tensor on ``target_device``.

    A leading batch dimension of size 1 is added. The NpzFile is closed by the
    ``with`` block, so no file handle is leaked.
    """
    with np.load(path) as npz:
        array = npz['arr']
    return torch.tensor(array).to(target_device).unsqueeze(0)


clip_features = load_npz_feature(f"{FEATURES_DIR}/clip_features/{SAMPLE_ID}.npz", device)
sam_features = load_npz_feature(f"{FEATURES_DIR}/sam_features/{SAMPLE_ID}.npz", device)

In [11]:
# Generate a report from the single loaded example, using an empty prompt.
output = model.generate(
    {
        "clip_features": clip_features,
        "sam_features": sam_features,
        "prompt": "",
    },
    max_length=100,
)
output

["The umbilical cord is attached to the placenta. The amniotic sac has a diameter of about 2 cm, and there are 3-4 small blood vessels with an average diameter of 0.15 - 0.20 cm in the upper part of the amniotic sac. No fetal movement can be seen inside the amniotic sac at this time. Pregnant woman's body mass index (BMI) is 28.5 kg/m2; gestational age"]

In [None]:
from torch import nn

# Small MLP head: 768-d input -> 7 logits, trained with binary cross-entropy.
# It gets its own name so it does not overwrite the SamBlip `model` used by earlier cells.
# NOTE(review): `model_input` and `target` are not defined in any visible cell; define them
# before running. `model_input` needs a trailing dim of 768. `target` must be a float tensor
# with the same shape as the head output (trailing dim 7).
classifier_head = nn.Sequential(
    nn.Linear(768, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 7),
)
head_logits = classifier_head(model_input)
# Fused sigmoid + BCE on raw logits: numerically stable, and (unlike
# F.binary_cross_entropy on sigmoid outputs) permitted inside an autocast region.
F.binary_cross_entropy_with_logits(head_logits, target)