In [49]:

import torch
import torch.nn as nn

from torch.utils.data import Dataset,DataLoader
from torch.nn import functional as F
from torch.distributions.categorical import Categorical
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

sys.path.append('C:/Users/Raphael/OneDrive/备份/桌面/大学生活和学习/3PI/毕业设计/code/perovskite_solar_cells/generative_method/transformer/model')
from cifdataloader import CifDataLoader
device =  torch.device('cuda')

# Parameters

In [50]:
VOCAB_SIZE = 1001
MAX_LEN = 159
EMBEDDING_DIM = 256
KEY_DIM = 256
N_HEADS = 2
FEED_FORWARD_DIM = 256
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False
BATCH_SIZE = 32
EPOCHS = 20
LAYER_TRANSFROMER = 4
REPEAT = 5  
JSON_DIR = 'C:/Users/Raphael/OneDrive/备份/桌面/大学生活和学习/3PI/毕业设计/code/perovskite_solar_cells/generative_method/transformer/datas/cif_datas'
ORDER = ['M','X','C','N','H']

# Load Dataset

In [51]:
cifdl = CifDataLoader(JSON_DIR, ORDER)

100%|██████████| 872/872 [00:09<00:00, 93.53it/s] 


In [52]:
class BHDataset(Dataset):
    def __init__(self, cifdl):
        self.datas = cifdl.direct_array

    def __len__(self):
        return 5*len(self.datas)

    def __getitem__(self, idx):
        data = torch.tensor(self.datas[idx%REPEAT]).to(device)
        return data[:-1], data[1:]

# 定义数据集

dataset = BHDataset(cifdl)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [53]:
# 加载一些图像
data_iter = iter(dataloader)
data, label = next(data_iter)
data = data[1]
label = label[1]
print(data)
print(label)


tensor([ 50,  50,  50,  82,  17,  35,  35,  35,  35,  35,  35,  35,  35,  35,
         35,  53,   5, 524,  25, 455,  40, 681,  20, 543, 654, 448,   3,  20,
        517,   2, 901, 390,  27, 281, 767, 496, 293, 700, 726,  86, 232, 668,
        630, 192, 188, 696, 630, 247, 990, 180, 199, 989, 625, 239, 689, 702,
        722, 639, 237, 669,  86,  71, 546, 926, 422, 405, 301, 795, 107, 304,
        490, 570, 922,  40, 851, 925, 517, 587, 302,  87, 807, 301, 398, 380,
        925, 754, 126, 925, 451, 640, 425,  10, 743, 220, 476, 319, 971, 807,
        168, 851, 475, 316, 848, 805, 167, 971, 452, 641, 221,  15, 742, 426,
        602, 589, 303, 175, 811, 296, 264, 374, 925, 708, 124, 920, 494, 360,
        435,  53, 144, 213, 406, 610, 975, 789, 763, 837, 405, 610, 834, 789,
        763, 976, 492, 361, 212,  55, 143, 437, 265, 417, 302, 722, 101, 307,
        595, 573, 921, 155, 857], device='cuda:0')
tensor([ 50,  50,  82,  17,  35,  35,  35,  35,  35,  35,  35,  35,  35,  35,
         53, 

# 遮挡块

In [54]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    i = torch.arange(n_dest).unsqueeze(1)
    j = torch.arange(n_src)
    m = i >= j - n_src + n_dest
    mask = m.type(dtype)
    mask = mask.reshape(1, n_dest, n_src)
    return mask.repeat(batch_size, 1, 1)

mask = causal_attention_mask(1, 10, 10, torch.int32)
np.transpose(mask[0])

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int32)

# 定义Transformer块

In [55]:
class TransformerBlock(nn.Module):
    def __init__(self, num_heads, key_dim, embed_dim, ff_dim, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.dropout_1 = nn.Dropout(dropout_rate)
        self.ln_1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.dropout_2 = nn.Dropout(dropout_rate)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, inputs):
        input_shape = inputs.shape
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1).bool().to(inputs.device)
        # causal_mask = causal_mask.repeat(batch_size*, 1, 1)
        attention_output, attention_scores = self.attn(
            inputs,
            inputs,
            inputs,
            attn_mask=causal_mask,
            need_weights=True,
        )
        attention_output = self.dropout_1(attention_output)
        out1 = self.ln_1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout_2(ffn_output)
        return self.ln_2(out1 + ffn_output), attention_scores

In [56]:
# # Create an instance of TransformerBlock
# transformer_block = TransformerBlock(num_heads=N_HEADS, key_dim=KEY_DIM, embed_dim=EMBEDDING_DIM, ff_dim=FEED_FORWARD_DIM)

# # Generate some random input data
# input_data = torch.randn(BATCH_SIZE, MAX_LEN, EMBEDDING_DIM)

# # Pass the input data through the TransformerBlock
# output, attention_scores = transformer_block(input_data)

# # Print the output and attention scores
# print("Output shape:", output.shape)
# print("Attention scores shape:", attention_scores.shape)


# 位置编码


In [57]:
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, max_len, vocab_size, embed_dim):
        super(TokenAndPositionEmbedding, self).__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_len, embed_dim)

    def forward(self, x):
        maxlen = x.size(1)
        positions = torch.arange(0, maxlen).unsqueeze(0).to(x.device)
        return self.token_emb(x) + self.pos_emb(positions)

In [58]:
# # Create an instance of TokenAndPositionEmbedding
# embedding = TokenAndPositionEmbedding(max_len=MAX_LEN, vocab_size=VOCAB_SIZE, embed_dim=EMBEDDING_DIM)

# # Generate some random input data
# input_data = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MAX_LEN))

# # Pass the input data through the TokenAndPositionEmbedding module
# output = embedding(input_data)

# # Print the output
# print("Output shape:", output.shape)


# GPT模型

In [64]:
class GPT(nn.Module):
    def __init__(self, max_len, vocab_size, embed_dim, n_heads, key_dim, ff_dim, cifdl):
        super(GPT, self).__init__()
        self.token_and_pos_embedding = TokenAndPositionEmbedding(max_len, vocab_size, embed_dim)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(n_heads, key_dim, embed_dim, ff_dim) for _ in range(LAYER_TRANSFROMER)])
        self.dense = nn.Linear(embed_dim, vocab_size)
        self.cifdl = cifdl

    def forward(self, x):
        x = self.token_and_pos_embedding(x)
        for transformer_block in self.transformer_blocks:
            x, attention_scores = transformer_block(x)
        outputs = F.log_softmax(self.dense(x), dim=-1)
        return outputs, attention_scores
    
    def generate(self, start_prompt, max_tokens, temperature):
        start_tokens = start_prompt.copy()
        sample_token = None
        info = []
        while len(start_tokens) < max_tokens:
            x = torch.tensor(start_tokens).unsqueeze(0).to(device)
            y, att = self(x)
            sample_token, probs = self.sample_from(y[0][-1], temperature)
            while sample_token > VOCAB_SIZE:
                sample_token, probs = self.sample_from(y[0][-1], temperature)
            info.append(
                {
                    "prompt": start_prompt,
                    "word_probs": probs,
                }
            )
            start_tokens.append(sample_token.item())
            if len(start_tokens) > 16:
                start_prompt.append(self.cifdl.bin_to_coordinate(sample_token.item()))
            else:
                start_prompt.append(sample_token.item())
        print(f"\ngenerated text:\n{start_prompt}\n")
        return info

    def sample_from(self, logits, temperature):
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        m = Categorical(probs)
        return m.sample(), probs

gpt = GPT(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM, N_HEADS, KEY_DIM, FEED_FORWARD_DIM,cifdl).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gpt.parameters())

In [60]:
# # Generate random input data
# test_input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MAX_LEN))

# # Pass the input data through the GPT model
# outputs, attention_scores = gpt(test_input)


In [61]:
# # 初始化模型
# gpt = GPT(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM, N_HEADS, KEY_DIM, FEED_FORWARD_DIM,cifdl).to(device)

# # 生成文本
# info = gpt.generate([0], 100, 1.0)

# # 打印生成的信息
# print(info)

In [65]:
def train_model(model, dataloader, optimizer, criterion, device, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()

            outputs, _ = model(inputs)
            loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch: {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

In [66]:
train_model(gpt, dataloader, optimizer, loss_fn, device, EPOCHS)

Epoch: 1/20, Loss: 0.4686
Epoch: 2/20, Loss: 0.0101
Epoch: 3/20, Loss: 0.0084
Epoch: 4/20, Loss: 0.0077
Epoch: 5/20, Loss: 0.0075
Epoch: 6/20, Loss: 0.0073
Epoch: 7/20, Loss: 0.0073
Epoch: 8/20, Loss: 0.0071
Epoch: 9/20, Loss: 0.0070
Epoch: 10/20, Loss: 0.0069
Epoch: 11/20, Loss: 0.0069
Epoch: 12/20, Loss: 0.0069
Epoch: 13/20, Loss: 0.0069
Epoch: 14/20, Loss: 0.0069
Epoch: 15/20, Loss: 0.4013
Epoch: 16/20, Loss: 0.0472
Epoch: 17/20, Loss: 0.0113
Epoch: 18/20, Loss: 0.0086
Epoch: 19/20, Loss: 0.0079
Epoch: 20/20, Loss: 0.0075


In [69]:
gpt.generate([1], 159, 1.0)


generated text:
[1, 50, 82, 82, 17, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 53, 4.0363323, 0.1369056, 6.291422346480002, 0.07977874265999998, 4.5275384, 5.917094069999999, 4.004426570700001, 0.0008064600000000001, 0.06489072, 0.0061432, 4.38511107, 0.08652095999999998, 4.363745462460001, 0.008871060000000002, 8.940285689999998, 3.7471689948, 0.09012600000000001, 2.94948471, 7.7246352, 4.22665686, 3.0473578199999998, 6.47909964, 6.9607479968, 0.3675, 2.3441214, 6.1414447999999995, 5.6596424, 1.73630838, 1.7008241400000002, 6.444999999999999, 5.6371554, 2.54474, 11.77014771, 1.6099999999999999, 1.77098616, 11.739939960000001, 5.622783200000001, 2.432557880760001, 6.3677937, 6.525874320000001, 6.8065224, 5.712608704960002, 2.432557880760001, 6.1686504, 0.3675, 0.27903516, 4.567030399999998, 9.03211725, 3.8508464999999994, 3.8125399999999994, 3.02923317, 7.895767199999999, 0.583604, 3.0336999999999996, 4.209721200000001, 4.985645600000001, 9.06111669, 0.1369056, 8.23463265, 9.06353331, 4.

[{'prompt': [1,
   50,
   82,
   82,
   17,
   35,
   35,
   35,
   35,
   35,
   35,
   35,
   35,
   35,
   35,
   53,
   4.0363323,
   0.1369056,
   6.291422346480002,
   0.07977874265999998,
   4.5275384,
   5.917094069999999,
   4.004426570700001,
   0.0008064600000000001,
   0.06489072,
   0.0061432,
   4.38511107,
   0.08652095999999998,
   4.363745462460001,
   0.008871060000000002,
   8.940285689999998,
   3.7471689948,
   0.09012600000000001,
   2.94948471,
   7.7246352,
   4.22665686,
   3.0473578199999998,
   6.47909964,
   6.9607479968,
   0.3675,
   2.3441214,
   6.1414447999999995,
   5.6596424,
   1.73630838,
   1.7008241400000002,
   6.444999999999999,
   5.6371554,
   2.54474,
   11.77014771,
   1.6099999999999999,
   1.77098616,
   11.739939960000001,
   5.622783200000001,
   2.432557880760001,
   6.3677937,
   6.525874320000001,
   6.8065224,
   5.712608704960002,
   2.432557880760001,
   6.1686504,
   0.3675,
   0.27903516,
   4.567030399999998,
   9.03211725,
   3