In [1]:
import torch

class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
    
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        return x

In [5]:
model = MyNetwork()
print(model)
print(model.parameters)
print(model.state_dict().keys())

MyNetwork(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
)
<bound method Module.parameters of MyNetwork(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
)>
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias'])


In [9]:
torch.save(model.state_dict(), 'model_weight.pt')
model_new = MyNetwork()
model_new.load_state_dict(torch.load('model_weight.pt', weights_only=True))
print(model_new.state_dict().keys())
print(model_new.conv1.weight[0])
print(model.conv1.weight[0])

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias'])
tensor([[[-0.0890,  0.0042, -0.0427,  0.0794, -0.0911],
         [-0.0134,  0.0673,  0.1109,  0.0731, -0.0212],
         [-0.0748, -0.0502, -0.0151, -0.0158,  0.0622],
         [ 0.0674,  0.0024,  0.0136,  0.0365, -0.1080],
         [-0.0154,  0.0232,  0.0612,  0.0337,  0.0293]],

        [[ 0.0887, -0.0952,  0.0437,  0.0708, -0.0881],
         [-0.0270, -0.0858,  0.1125, -0.0995,  0.0084],
         [-0.0762,  0.0590, -0.0690, -0.0739, -0.0135],
         [ 0.1025, -0.0979, -0.1078,  0.1074, -0.0003],
         [ 0.0081,  0.0468,  0.0784, -0.1064,  0.0047]],

        [[ 0.0773,  0.0785,  0.0373,  0.1052, -0.0265],
         [ 0.0410, -0.0077,  0.1055,  0.0623,  0.0079],
         [-0.1011, -0.0397,  0.0509, -0.0085,  0.0571],
         [ 0.0295, -0.0917, -0.0098, -0.0709,  0.0425],
         [ 0.0484,  0.0249, -0.0834, -0.0722, -0.0796]]],
       grad_fn=<SelectBackward0>)
tensor([[[-0.0890,  0

In [16]:
import transformers
model_id = '/Users/jingweixu/Downloads/Meta-Llama-3.1-8B-Instruct'
llama = transformers.LlamaForCausalLM.from_pretrained(model_id)
# llama.save_pretrained('/Users/jingweixu/Downloads/llama_test', from_pt=True)

Loading checkpoint shards: 100%|██████████| 4/4 [00:35<00:00,  8.97s/it]


In [17]:
from peft import get_peft_model, LoraConfig, TaskType
print(llama)
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

peft_model = get_peft_model(llama, peft_config)
print(peft_model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n