In [1]:
#存默一贴摁卡，图

In [1]:
import os
import math
import random
from sklearn.metrics import balanced_accuracy_score
import copy
from functools import wraps, partial

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR
import torch.backends.cudnn as cudnn

from einops import repeat, rearrange 
from einops.layers.torch import Rearrange

import matplotlib.pyplot as plt
from torch.utils.data.dataloader import default_collate

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = "/data2/patho-vit_5_23_ccsurv/up"
# 此处patho_vit后的_5_23为包的版本号
# 若此jupyternotebook运行中kernel挂掉，重启后仅需运行此一代码块，然后跳到需要运行的代码块即可

In [3]:
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [4]:
# 第二部门：机场1——全面训练海绵宝宝
batch_size = 260
# 原始文献为50，若报错CUDA OUT OF MEMORY,则减少batch_size个数。此外，需关闭kernel重启，只重启机场1代码即可。
num_workers = 48
# 原始文献此处为4

epochs = 10000
# 原始文献中默认值为50
#test_every = 1
# 原始文献中默认值为10

#n = 0
# n在0到 400/25 = 16之间

In [5]:
import math

from multiprocessing import Value

from logging import getLogger

import torch

_GLOBAL_SEED = 0
logger = getLogger()


class MaskCollator(object):

    def __init__(
        self,
        input_size = (384, 384),
        patch_size = 16,
        enc_mask_scale = (0.85, 1.0),
        pred_mask_scale = (0.15, 0.2),
        aspect_ratio = (0.75, 1.5),
        nenc = 1,
        npred = 4,
        min_keep = 4,
        allow_overlap = False
    ):
        super().__init__()
        if not isinstance(input_size, tuple):
            input_size = (input_size, ) * 2
        self.patch_size = patch_size
        self.height, self.width = input_size[0] // patch_size, input_size[1] // patch_size
        self.enc_mask_scale = enc_mask_scale
        self.pred_mask_scale = pred_mask_scale
        self.aspect_ratio = aspect_ratio
        self.nenc = nenc
        self.npred = npred
        self.min_keep = min_keep  # minimum number of patches to keep
        self.allow_overlap = allow_overlap  # whether to allow overlap b/w enc and pred masks
        self._itr_counter = Value('i', -1)  # collator is shared across worker processes

    def step(self):
        i = self._itr_counter
        with i.get_lock():
            i.value += 1
            v = i.value
        return v

    def _sample_block_size(self, generator, scale, aspect_ratio_scale):
        _rand = torch.rand(1, generator=generator).item()
        # -- Sample block scale
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)
        # -- Sample block aspect-ratio
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + _rand * (max_ar - min_ar)
        # -- Compute block height and width (given scale and aspect-ratio)
        h = int(round(math.sqrt(max_keep * aspect_ratio)))
        w = int(round(math.sqrt(max_keep / aspect_ratio)))
        while h >= self.height:
            h -= 1
        while w >= self.width:
            w -= 1

        return (h, w)

    def _sample_block_mask(self, b_size, acceptable_regions=None):
        h, w = b_size

        def constrain_mask(mask, tries=0):
            """ Helper to restrict given mask to a set of acceptable regions """
            N = max(int(len(acceptable_regions)-tries), 0)
            for k in range(N):
                mask *= acceptable_regions[k]
        # --
        # -- Loop to sample masks until we find a valid one
        tries = 0
        timeout = og_timeout = 20
        valid_mask = False
        while not valid_mask:
            # -- Sample block top-left corner
            top = torch.randint(0, self.height - h, (1,))
            left = torch.randint(0, self.width - w, (1,))
            mask = torch.zeros((self.height, self.width), dtype=torch.int32)
            mask[top:top+h, left:left+w] = 1
            # -- Constrain mask to a set of acceptable regions
            if acceptable_regions is not None:
                constrain_mask(mask, tries)
            mask = torch.nonzero(mask.flatten())
            # -- If mask too small try again
            valid_mask = len(mask) > self.min_keep
            if not valid_mask:
                timeout -= 1
                if timeout == 0:
                    tries += 1
                    timeout = og_timeout
                    logger.warning(f'Mask generator says: "Valid mask not found, decreasing acceptable-regions [{tries}]"')
        mask = mask.squeeze()
        # --
        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top:top+h, left:left+w] = 0
        # --
        return mask, mask_complement

    def __call__(self, batch):
        '''
        Create encoder and predictor masks when collating imgs into a batch
        # 1. sample enc block (size + location) using seed
        # 2. sample pred block (size) using seed
        # 3. sample several enc block locations for each image (w/o seed)
        # 4. sample several pred block locations for each image (w/o seed)
        # 5. return enc mask and pred mask
        '''
        B = len(batch)

        collated_batch = default_collate(batch)

        seed = self.step()
        g = torch.Generator()
        g.manual_seed(seed)
        p_size = self._sample_block_size(
            generator=g,
            scale=self.pred_mask_scale,
            aspect_ratio_scale=self.aspect_ratio)
        e_size = self._sample_block_size(
            generator=g,
            scale=self.enc_mask_scale,
            aspect_ratio_scale=(1., 1.))

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_pred = self.height * self.width
        min_keep_enc = self.height * self.width
        for _ in range(B):

            masks_p, masks_C = [], []
            for _ in range(self.npred):
                mask, mask_C = self._sample_block_mask(p_size)
                masks_p.append(mask)
                masks_C.append(mask_C)
                min_keep_pred = min(min_keep_pred, len(mask))
            collated_masks_pred.append(masks_p)

            acceptable_regions = masks_C
            try:
                if self.allow_overlap:
                    acceptable_regions= None
            except Exception as e:
                logger.warning(f'Encountered exception in mask-generator {e}')

            masks_e = []
            for _ in range(self.nenc):
                mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
                masks_e.append(mask)
                min_keep_enc = min(min_keep_enc, len(mask))
            collated_masks_enc.append(masks_e)

        collated_masks_pred = [[cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred]
        collated_masks_pred = default_collate(collated_masks_pred)
        # --
        collated_masks_enc = [[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc]
        collated_masks_enc = default_collate(collated_masks_enc)

        return collated_batch, collated_masks_enc, collated_masks_pred

In [6]:
transform = transforms.Compose([
    #transforms.Resize([384, 384]),
    transforms.ToTensor(), 
    transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                             std=(0.229, 0.224, 0.225))])


#train_lib = "files/store_5_train_{}.db".format((n+1) * 25)
train_lib = "/data2/patho-vit_5_23_ccsurv/SZL_exva/files/store_4.db"

#valid_lib = "files/store_3_valid.db"

from patho_vit.airport1 import pathovitdataset

train_dataset = pathovitdataset(libraryfile = train_lib, transform = transform)

#valid_dataset = pathovitdataset(libraryfile = train_lib, transform = transform)

#x_test = []
#y_test = []
#for i in range(4):
#    x = valid_dataset[i][0]
#    x_test.append(x)
#    y = valid_dataset[i][1]
#    y_test.append(y)
#test_dataset = [(x, y) for x, y in zip(x_test, y_test)] #包装为数据对

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size, shuffle = True,
    num_workers = num_workers,
    collate_fn = MaskCollator()
)
#valid_loader = torch.utils.data.DataLoader(
#    valid_dataset,
#    batch_size = batch_size, shuffle = False,
#    num_workers = num_workers
#)

In [7]:
def apply_mask(x, mask):
    all_x = []
    for m in mask:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    return torch.cat(all_x, dim=0)

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance
            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

def get_module_device(module):
    return next(module.parameters()).device

def loss_fn(z, h):
    loss = F.smooth_l1_loss(z, h)
    loss = loss.sum(dim = -1).mean()
    return loss

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class Attention(nn.Module):
    def __init__(self, dim, heads = 12, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head 
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

In [8]:
class Context_encoder(nn.Module):
    def __init__(self, *, image_size = 384, patch_size = 16, channels = 3, emb_dropout = 0., 
                 dim = 768, depth = 12, heads = 12, dim_head = 64, mlp_dim = 768 * 4, dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, img, mask):
        device = img.device
        
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        
        x += self.pos_embedding[:, :(n + 1)]
        
        x = self.dropout(x)

        x = self.transformer(x)
        
        x = apply_mask(x, mask)
        
        return x

In [9]:
class Predictor(nn.Module):
    def __init__(self, *, num_patches = 576, encoder_dim = 768, 
                 decoder_dim = 768, decoder_depth = 8, decoder_heads = 8, decoder_dim_head = 64):
        super().__init__()
        
        self.num_patches = num_patches
        self.decoder_dim = decoder_dim
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.black_token = nn.Parameter(torch.randn(1, 1, decoder_dim))
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)

    def forward(self, x, context_mask, target_mask):
        device = x.device
        batch, num_white, _ = x.shape
        
        white_tokens = self.enc_to_dec(x)
        white_tokens += self.decoder_pos_emb(context_mask[0])
        
        black_tokens1 = repeat(self.black_token, '1 1 d -> b n d', b = batch, n = target_mask[0].size(-1))
        black_tokens1 = black_tokens1 + self.decoder_pos_emb(target_mask[0])
        
        black_tokens2 = repeat(self.black_token, '1 1 d -> b n d', b = batch, n = target_mask[1].size(-1))
        black_tokens2 = black_tokens2 + self.decoder_pos_emb(target_mask[1])
        
        black_tokens3 = repeat(self.black_token, '1 1 d -> b n d', b = batch, n = target_mask[2].size(-1))
        black_tokens3 = black_tokens3 + self.decoder_pos_emb(target_mask[2])
        
        black_tokens4 = repeat(self.black_token, '1 1 d -> b n d', b = batch, n = target_mask[3].size(-1))
        black_tokens4 = black_tokens4 + self.decoder_pos_emb(target_mask[3])
        
        decoder_tokens = torch.zeros(batch, self.num_patches, self.decoder_dim, device=device)
        batch_range = torch.arange(batch, device = device)[:, None]
        decoder_tokens[batch_range, context_mask[0]] = white_tokens
        decoder_tokens[batch_range, target_mask[0]] = black_tokens1
        decoder_tokens[batch_range, target_mask[1]] = black_tokens2
        decoder_tokens[batch_range, target_mask[2]] = black_tokens3
        decoder_tokens[batch_range, target_mask[3]] = black_tokens4
        
        decoder_tokens = self.decoder(decoder_tokens)
        
        context = apply_mask(decoder_tokens, target_mask)
        
        return context

In [10]:
class Jepa(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.context_encoder = Context_encoder()
        self.predictor = Predictor()
        #self.target_encoder = Target_encoder()
        self.target_encoder = None
        
        device = get_module_device(Context_encoder())
        self.to(device)
        
        self.forward(torch.randn(2, 3, 384, 384, device = device), torch.ones(2, 576, dtype  = torch.int64), [torch.ones(2, 576, dtype  = torch.int64), 
                                                                                                             torch.ones(2, 576, dtype  = torch.int64), 
                                                                                                              torch.ones(2, 576, dtype  = torch.int64),                                     
                                                                                                              torch.ones(2, 576, dtype  = torch.int64)])
    @singleton("target_encoder")
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.context_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder
    
#    def reset_moving_average(self):
#        del self.target_encoder
#        self.target_encoder = None
        
    def forward(self, x, context_mask, target_mask):
        context = self.context_encoder(x, context_mask)
        context_outputs = self.predictor(context, context_mask, target_mask)
        
        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_outputs = target_encoder(x, target_mask)
        
        return context_outputs, target_outputs


In [11]:
# 下臂清零
#from collections import OrderedDict
#weights = torch.load("/data2/patho-vit_5_23_lung/output_jepa/gpvit_weight_5_5_epoch_48.pt")

#new_dict = OrderedDict()
#for k, v in weights.items():
#    if "module.target_encoder" not in k:
#        new_key = k[7:]
#        new_dict[new_key] = v

#jepa = Jepa()
#jepa.to(device)
#jepa.load_state_dict(new_dict, strict = False)
#jepa = nn.DataParallel(jepa)

In [12]:
from collections import OrderedDict
#weights = torch.load("output_jepa/gpvit_weight_4_13_epoch_16.pt")
weights = torch.load("/data2/patho-vit_5_23_ccim/up/13/11.20waibu/gpvit_weight_11_20_i_600.pt")
#new_dict = OrderedDict()
#for k, v in weights.items():
#    if "module.target_encoder" not in k:
#        new_key = k[7:]
#        new_dict[new_key] = v

jepa = Jepa()
jepa.to(device)
#jepa.load_state_dict(new_dict, strict = True)
jepa = nn.DataParallel(jepa)
jepa.load_state_dict(weights, strict = False)

<All keys matched successfully>

In [13]:
optimizer = torch.optim.AdamW(jepa.parameters(), lr= 1.5e-4, weight_decay = 0.2, betas = [0.9, 0.95])
#optimizer = torch.optim.AdamW(mae.parameters(), lr= 3e-4)
# 凯明原文，lr的计算方法为：基础lr（1.5e-4）* batch_size / 256。2月15日尝试使用，首次出现了每13轮的反弹，又改回3e-4

# 凯明原文，预热迭代40次，是指模型的学习速度从极低的值慢慢涨到指定的lr，理论上可避免训练初期，loss的严重振荡。
# 可能的代码如下。本次宫颈癌未用，下次有需要从头训练的瘤种，可以尝试。
# from transformers import get_linear_scheduler_with_warmup
# total_steps = len(train_loader) * epochs
# scheduler = get_linear_scheduler_with_warmup(optimizer, num_warmup_steps = len(train_loader) * 40, 
#                                              num_training_step = total_steps)

# 凯明使用了余弦退火，但没交待这里的T_0和T_mult是如何设置的
#scheduler = CosineAnnealingWarmRestarts(optimizer, T_0 = 40, T_mult = 1, eta_min= 0.75e-4)
#scheduler = CosineAnnealingLR(optimizer, T_max = 40)
#scheduler = StepLR(optimizer, step_size = 1, gamma = 0.7)

fconv = open(os.path.join("13/12.4/gpvit_convergence.csv"), "w")
fconv.write("epoch, metric, value\n")
fconv.close()

In [14]:
best_loss = 10e5
best_epoch_loss = 10e5
loss_curve = []
total_step = len(train_loader)
for epoch in range(epochs):
    epoch_loss = 0
    for i, (images, context_mask, target_mask) in enumerate(train_loader):
        #images = images[0][0].to(device)
        context_outputs, target_outputs = jepa(images[0], context_mask, target_mask) # 此处有标签，所以用[0]
        loss = loss_fn(context_outputs, target_outputs)
        loss *= 10e5   #6.20第1次
           
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            #m = next(momentum_scheduler)
            m = 0.9
            for param_q, param_k in zip(jepa.module.context_encoder.parameters(), jepa.module.target_encoder.parameters()):
                param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)
        
        epoch_loss += loss / total_step
    
        
        #if (i + 1) % 15000 == 0:
        #if (i + 1) % 1 == 0:
        #    print("Epoch [{}/{}], Step [{}/{}] Loss: {:.10f}"
       #          .format(epoch+1, epochs, i+1, total_step, loss))
            
        if (i + 1) % 200 == 0:
            if loss < best_loss:
                best_loss = loss
                torch.save(jepa.state_dict(), "13/12.4/gpvit_weight_12_4_i_{}.pt".format(i+1))
            fconv = open(os.path.join("13/12.4/gpvit_convergence.csv"), "a")
            fconv.write("{}, loss, {:.10f}\n".format(i+1, loss.item()))
            fconv.close()
    
    #scheduler.step()
    loss_curve.append(epoch_loss.cpu().detach())
    
    if (epoch+1) % 1 == 0:
    #if (epoch) % 1 == 0:
        if  epoch_loss < best_epoch_loss:
            best_epoch_loss = epoch_loss
            torch.save(jepa.state_dict(), "13/12.4/gpvit_weight_12_4_epoch_{}.pt".format(epoch+1))
        
        print("Epoch [{}/{}], Loss: {:.10f}"
             .format(epoch+1, epochs, epoch_loss.item()))
        
        fconv = open(os.path.join("13/12.4/gpvit_convergence.csv"), "a")
        fconv.write("{}, loss, {:.10f}\n".format(epoch+1, epoch_loss.item()))
        fconv.close()
    
    #if (epoch+1) >= 5:
    #if epoch % 1 == 0:
    


Epoch [1/10000], Loss: 10.8348598480
Epoch [2/10000], Loss: 10.8564672470
Epoch [3/10000], Loss: 9.9641628265
Epoch [4/10000], Loss: 9.8310670853
Epoch [5/10000], Loss: 11.4308090210
Epoch [6/10000], Loss: 11.1752147675
Epoch [7/10000], Loss: 10.2119522095
Epoch [8/10000], Loss: 10.1589555740
Epoch [9/10000], Loss: 10.3110685349
Epoch [10/10000], Loss: 10.4685993195
Epoch [11/10000], Loss: 9.6523704529
Epoch [12/10000], Loss: 9.8823804855
Epoch [13/10000], Loss: 9.9025526047
Epoch [14/10000], Loss: 9.7417573929
Epoch [15/10000], Loss: 9.3705539703
Epoch [16/10000], Loss: 10.1730194092
Epoch [17/10000], Loss: 9.2606277466
Epoch [18/10000], Loss: 8.8095960617
Epoch [19/10000], Loss: 9.4374237061
Epoch [20/10000], Loss: 9.3924455643


KeyboardInterrupt: 

In [15]:
#torch.save(jepa.state_dict(), "13/gpvit_weight_7_29_epoch_2.pt")

In [5]:
# 机场2，为每个补丁抽出特征，为整张切片制作特征图像

In [1]:
import os
from sklearn.metrics import balanced_accuracy_score
import gc

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import random
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms

from torch.optim.lr_scheduler import StepLR
import torch.backends.cudnn as cudnn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = "/data2/patho-vit_5_23_ccsurv"
# 此处patho_vit后的_5_23为包的版本号
# 若此jupyternotebook运行中kernel挂掉，重启后仅需运行此一代码块，然后跳到需要运行的代码块即可

In [2]:
from patho_vit.vit_luci2 import ViT as ViT
from patho_vit.airport2 import train_features_extractor_gpvit2 as Extractor
from collections import OrderedDict

In [3]:
vit = ViT(
    image_size = 384,
    patch_size = 16,
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 768 * 4,
    num_classes = 1000
)
output, feature = vit(torch.randn(2, 3, 384, 384))
feature.shape

torch.Size([2, 1, 768])

In [4]:
# 导入定制vit，加载目标编码器的权重

#weights = torch.load("output_jepa/gpvit_weight_4_22_epoch_1.pt")
weights = torch.load("/data2/patho-vit_5_23_ccsurv/up/13/12.4/gpvit_weight_12_4_epoch_18.pt", map_location=torch.device('cpu'))
path = "/data2/patho-vit_5_23_ccsurv/up"
new_dict = OrderedDict()
for k, v in weights.items():
    if "module.target_encoder" in k:
        new_key = k[22:]
        new_dict[new_key] = v

vit = ViT(
    image_size = 384,
    patch_size = 16,
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 768 * 4,
    num_classes = 1000
)
vit.to(device)
vit.load_state_dict(new_dict, strict = False)
vit = nn.DataParallel(vit)

In [10]:
# 一个文件夹一个文件夹的抽特征，并重构（384，384，768）的图像
transform = transforms.Compose([
    transforms.Resize([384, 384]),
    transforms.ToTensor(), 
    transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                             std=(0.229, 0.224, 0.225))])

#train_lib = "files/store_5_train_{}.db".format((n+1) * 25)
train_lib_0 = torch.load("/data2/patho-vit_5_23_ccim/up/files_label/store_4_qlyy.db")
#train_lib_1 = torch.load("/data2/patho-vit_5_23_ccim/up/files_label/store_4_PD+WUQIAN.db")
#train_lib_2 = torch.load("/data2/patho-vit_5_23_ccim/up/files_label/store_4_im.db")
#train_lib_3 = torch.load("/data2/patho-vit_5_23_ccim/up/files_label/store_4_qlyywaibu.db")
extractor0 = Extractor(path = path, train_lib = train_lib_0, transform = transform, batch_size = 200, model = vit, device = device)

In [11]:
extractor1 = Extractor(path = path, train_lib = train_lib_1, transform = transform, batch_size = 200, model = vit, device = device)

In [12]:
extractor2 = Extractor(path = path, train_lib = train_lib_2, transform = transform, batch_size = 200, model = vit, device = device)

In [13]:
extractor3 = Extractor(path = path, train_lib = train_lib_3, transform = transform, batch_size = 200, model = vit, device = device)

In [None]:
from PIL import Image
image = Image.open("cat.jpg")
transform = transforms.Compose([
    transforms.Resize(384),
    transforms.ToTensor()
])
image2 = transform(image)
plt.imshow(image2.permute(1, 2, 0));

In [None]:
loss_curve = []

for epoch in range(4000):
    epoch_loss = 0
    
        #images = images[0][0].to(device)
    context_outputs, target_outputs = jepa(image2.unsqueeze(0), context_mask, target_mask) # 此处有标签，所以用[0]
    loss = loss_fn(context_outputs, target_outputs)
        
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
        
    with torch.no_grad():
        #m = next(momentum_scheduler)
        m = 0.9
        for param_q, param_k in zip(jepa.module.context_encoder.parameters(), jepa.module.target_encoder.parameters()):
            param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)
        
    epoch_loss = loss 
    
    
    print("Epoch [{}/{}], Loss: {:.10f}"
                .format(epoch+1, 4000, loss))
            
        #    fconv = open(os.path.join("output/gpvit_convergence.csv"), "a")
        #    fconv.write("{}, loss, {:.4f}\n".format(epoch+1, loss.item()))
        #    fconv.close()
    
    #scheduler.step()
loss_curve.append(epoch_loss.cpu().detach())

In [None]:
from einops.layers.torch import Rearrange
rearrange = Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = 16, p2 = 16)
rearrange2 = Rearrange("b (h w) (p1 p2 c) -> b c (h p1) (w p2)", h = 24, p1 = 16, p2 = 16)
image_patches = rearrange(image2.unsqueeze(0))
image_patches.shape

In [None]:
target_mask_cache = []
for target_mask_iter in target_mask:
    target_mask_cache.append(target_mask_iter[0])

In [None]:
batch_range = torch.arange(1, device = device)[:, None]
context_indices = context_mask[0][0]

target1_indices = target_mask_cache[0]

target2_indices = target_mask_cache[1]

target3_indices = target_mask_cache[2]

target4_indices = target_mask_cache[3]

context_tokens = image_patches[batch_range, context_indices]
target1_tokens = image_patches[batch_range, target1_indices]
target2_tokens = image_patches[batch_range, target2_indices]
target3_tokens = image_patches[batch_range, target3_indices]
target4_tokens = image_patches[batch_range, target4_indices]

context_patches = torch.zeros(1, 576, 768, device = device)
target1_patches = copy.deepcopy(context_patches)
target2_patches = copy.deepcopy(context_patches)
target3_patches = copy.deepcopy(context_patches)
target4_patches = copy.deepcopy(context_patches)

context_patches[batch_range, context_indices] = context_tokens
target1_patches[batch_range, target1_indices] = target1_tokens
target2_patches[batch_range, target2_indices] = target2_tokens
target3_patches[batch_range, target3_indices] = target3_tokens
target4_patches[batch_range, target4_indices] = target4_tokens

context_patches = rearrange2(context_patches)
target1_patches = rearrange2(target1_patches)
target2_patches = rearrange2(target2_patches)
target3_patches = rearrange2(target3_patches)
target4_patches = rearrange2(target4_patches)

In [None]:
fig, axs = plt.subplots(1, 6)

axs[0].imshow(image2.permute(1, 2, 0));
axs[0].axis("off");

#axs[1].imshow(recon_image.cpu().detach().permute(1, 2, 0));
axs[1].imshow(context_patches.squeeze(0).permute(1, 2, 0));
axs[1].axis("off");
        
axs[2].imshow(target1_patches.squeeze(0).permute(1, 2, 0));
axs[2].axis("off");
        
axs[3].imshow(target2_patches.squeeze(0).permute(1, 2, 0));
axs[3].axis("off");

axs[4].imshow(target3_patches.squeeze(0).permute(1, 2, 0));
axs[4].axis("off");
        
axs[5].imshow(target4_patches.squeeze(0).permute(1, 2, 0));
axs[5].axis("off");